mirror of
https://github.com/element-hq/synapse
synced 2024-06-30 12:33:33 +00:00
Merge branch 'develop' into shay/account_suspension_pt_2
This commit is contained in:
commit
9b56bf8497
|
@ -1,3 +1,10 @@
|
|||
# Synapse 1.108.0 (2024-05-28)
|
||||
|
||||
No significant changes since 1.108.0rc1.
|
||||
|
||||
|
||||
|
||||
|
||||
# Synapse 1.108.0rc1 (2024-05-21)
|
||||
|
||||
### Features
|
||||
|
|
8
Cargo.lock
generated
8
Cargo.lock
generated
|
@ -485,18 +485,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.202"
|
||||
version = "1.0.203"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395"
|
||||
checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.202"
|
||||
version = "1.0.203"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838"
|
||||
checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
1
changelog.d/17083.misc
Normal file
1
changelog.d/17083.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve DB usage when fetching related events.
|
1
changelog.d/17164.bugfix
Normal file
1
changelog.d/17164.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix deduplicating of membership events to not create unused state groups.
|
1
changelog.d/17167.feature
Normal file
1
changelog.d/17167.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.
|
1
changelog.d/17176.misc
Normal file
1
changelog.d/17176.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.
|
1
changelog.d/17204.doc
Normal file
1
changelog.d/17204.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Update OIDC documentation: by default Matrix doesn't query userinfo endpoint, then claims should be put on id_token.
|
1
changelog.d/17211.misc
Normal file
1
changelog.d/17211.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce work of calculating outbound device lists updates.
|
1
changelog.d/17213.feature
Normal file
1
changelog.d/17213.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Support MSC3916 by adding unstable media endpoints to `_matrix/client` (#17213).
|
1
changelog.d/17215.bugfix
Normal file
1
changelog.d/17215.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
|
1
changelog.d/17219.feature
Normal file
1
changelog.d/17219.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add logging to tasks managed by the task scheduler, showing CPU and database usage.
|
1
changelog.d/17226.misc
Normal file
1
changelog.d/17226.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Move towards using `MultiWriterIdGenerator` everywhere.
|
1
changelog.d/17229.misc
Normal file
1
changelog.d/17229.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
|
1
changelog.d/17238.misc
Normal file
1
changelog.d/17238.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Change the `allow_unsafe_locale` config option to also apply when setting up new databases.
|
1
changelog.d/17239.misc
Normal file
1
changelog.d/17239.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
|
1
changelog.d/17240.bugfix
Normal file
1
changelog.d/17240.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Ignore attempts to send to-device messages to bad users, to avoid log spam when we try to connect to the bad server.
|
1
changelog.d/17241.bugfix
Normal file
1
changelog.d/17241.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix handling of duplicate concurrent uploading of device one-time-keys.
|
1
changelog.d/17242.misc
Normal file
1
changelog.d/17242.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Clean out invalid destinations from `device_federation_outbox` table.
|
1
changelog.d/17246.misc
Normal file
1
changelog.d/17246.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
|
1
changelog.d/17250.misc
Normal file
1
changelog.d/17250.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Stop logging errors when receiving invalid User IDs in key querys requests.
|
1
changelog.d/17251.bugfix
Normal file
1
changelog.d/17251.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix reporting of default tags to Sentry, such as worker name. Broke in v1.108.0.
|
1
changelog.d/17252.bugfix
Normal file
1
changelog.d/17252.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug where typing updates would not be sent when using workers after a restart.
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
|||
matrix-synapse-py3 (1.108.0) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.108.0.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 28 May 2024 11:54:22 +0100
|
||||
|
||||
matrix-synapse-py3 (1.108.0~rc1) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.108.0rc1.
|
||||
|
|
|
@ -525,6 +525,8 @@ oidc_providers:
|
|||
(`Options > Security > ID Token signature algorithm` and `Options > Security >
|
||||
Access Token signature algorithm`)
|
||||
- Scopes: OpenID, Email and Profile
|
||||
- Force claims into `id_token`
|
||||
(`Options > Advanced > Force claims to be returned in ID Token`)
|
||||
- Allowed redirection addresses for login (`Options > Basic > Allowed
|
||||
redirection addresses for login` ) :
|
||||
`[synapse public baseurl]/_synapse/client/oidc/callback`
|
||||
|
|
|
@ -242,12 +242,11 @@ host all all ::1/128 ident
|
|||
|
||||
### Fixing incorrect `COLLATE` or `CTYPE`
|
||||
|
||||
Synapse will refuse to set up a new database if it has the wrong values of
|
||||
`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
|
||||
of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
|
||||
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
|
||||
underneath the database, or if a different version of the locale is used on any
|
||||
replicas.
|
||||
Synapse will refuse to start when using a database with incorrect values of
|
||||
`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
|
||||
`database` section of the config, is set to true. Using different locales can
|
||||
cause issues if the locale library is updated from underneath the database, or
|
||||
if a different version of the locale is used on any replicas.
|
||||
|
||||
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
|
||||
the correct locale parameter (as shown above). It is also possible to change the
|
||||
|
|
24
poetry.lock
generated
24
poetry.lock
generated
|
@ -1536,13 +1536,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "phonenumbers"
|
||||
version = "8.13.35"
|
||||
version = "8.13.37"
|
||||
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "phonenumbers-8.13.35-py2.py3-none-any.whl", hash = "sha256:58286a8e617bd75f541e04313b28c36398be6d4443a778c85e9617a93c391310"},
|
||||
{file = "phonenumbers-8.13.35.tar.gz", hash = "sha256:64f061a967dcdae11e1c59f3688649e697b897110a33bb74d5a69c3e35321245"},
|
||||
{file = "phonenumbers-8.13.37-py2.py3-none-any.whl", hash = "sha256:4ea00ef5012422c08c7955c21131e7ae5baa9a3ef52cf2d561e963f023006b80"},
|
||||
{file = "phonenumbers-8.13.37.tar.gz", hash = "sha256:bd315fed159aea0516f7c367231810fe8344d5bec26156b88fa18374c11d1cf2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1673,13 +1673,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes
|
|||
|
||||
[[package]]
|
||||
name = "prometheus-client"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0"
|
||||
description = "Python client for the Prometheus monitoring system."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"},
|
||||
{file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"},
|
||||
{file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
|
||||
{file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
@ -1915,12 +1915,12 @@ plugins = ["importlib-metadata"]
|
|||
|
||||
[[package]]
|
||||
name = "pyicu"
|
||||
version = "2.13"
|
||||
version = "2.13.1"
|
||||
description = "Python extension wrapping the ICU C++ API"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "PyICU-2.13.tar.gz", hash = "sha256:d481be888975df3097c2790241bbe8518f65c9676a74957cdbe790e559c828f6"},
|
||||
{file = "PyICU-2.13.1.tar.gz", hash = "sha256:d4919085eaa07da12bade8ee721e7bbf7ade0151ca0f82946a26c8f4b98cdceb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1997,13 +1997,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "pyopenssl"
|
||||
version = "24.0.0"
|
||||
version = "24.1.0"
|
||||
description = "Python wrapper module around the OpenSSL library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pyOpenSSL-24.0.0-py3-none-any.whl", hash = "sha256:ba07553fb6fd6a7a2259adb9b84e12302a9a8a75c44046e8bb5d3e5ee887e3c3"},
|
||||
{file = "pyOpenSSL-24.0.0.tar.gz", hash = "sha256:6aa33039a93fffa4563e655b61d11364d01264be8ccb49906101e02a334530bf"},
|
||||
{file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
|
||||
{file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -2011,7 +2011,7 @@ cryptography = ">=41.0.5,<43"
|
|||
|
||||
[package.extras]
|
||||
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
|
||||
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
|
||||
test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
|
||||
|
||||
[[package]]
|
||||
name = "pysaml2"
|
||||
|
|
|
@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
|
|||
|
||||
[tool.poetry]
|
||||
name = "matrix-synapse"
|
||||
version = "1.108.0rc1"
|
||||
version = "1.108.0"
|
||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
|
|||
# add a lower bound to the Jinja2 dependency.
|
||||
Jinja2 = ">=3.0"
|
||||
bleach = ">=1.4.3"
|
||||
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
|
||||
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
|
||||
# generic over ParamSpecs.
|
||||
typing-extensions = ">=3.10.0.1"
|
||||
# We use `Self`, which were added in `typing-extensions` 4.0.
|
||||
typing-extensions = ">=4.0"
|
||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||
# with the latest security patches.
|
||||
cryptography = ">=3.4.7"
|
||||
|
|
|
@ -777,22 +777,74 @@ class Porter:
|
|||
await self._setup_events_stream_seqs()
|
||||
await self._setup_sequence(
|
||||
"un_partial_stated_event_stream_sequence",
|
||||
("un_partial_stated_event_stream",),
|
||||
[("un_partial_stated_event_stream", "stream_id")],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
||||
"device_inbox_sequence",
|
||||
[
|
||||
("device_inbox", "stream_id"),
|
||||
("device_federation_outbox", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"account_data_sequence",
|
||||
("room_account_data", "room_tags_revisions", "account_data"),
|
||||
[
|
||||
("room_account_data", "stream_id"),
|
||||
("room_tags_revisions", "stream_id"),
|
||||
("account_data", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"receipts_sequence",
|
||||
[
|
||||
("receipts_linearized", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"presence_stream_sequence",
|
||||
[
|
||||
("presence_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
|
||||
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
|
||||
await self._setup_auth_chain_sequence()
|
||||
await self._setup_sequence(
|
||||
"application_services_txn_id_seq",
|
||||
("application_services_txns",),
|
||||
"txn_id",
|
||||
[
|
||||
(
|
||||
"application_services_txns",
|
||||
"txn_id",
|
||||
)
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_lists_sequence",
|
||||
[
|
||||
("device_lists_stream", "stream_id"),
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"e2e_cross_signing_keys_sequence",
|
||||
[
|
||||
("e2e_cross_signing_keys", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"push_rules_stream_sequence",
|
||||
[
|
||||
("push_rules_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"pushers_sequence",
|
||||
[
|
||||
("pushers", "id"),
|
||||
("deleted_pushers", "stream_id"),
|
||||
],
|
||||
)
|
||||
|
||||
# Step 3. Get tables.
|
||||
|
@ -1101,12 +1153,11 @@ class Porter:
|
|||
async def _setup_sequence(
|
||||
self,
|
||||
sequence_name: str,
|
||||
stream_id_tables: Iterable[str],
|
||||
column_name: str = "stream_id",
|
||||
stream_id_tables: Iterable[Tuple[str, str]],
|
||||
) -> None:
|
||||
"""Set a sequence to the correct value."""
|
||||
current_stream_ids = []
|
||||
for stream_id_table in stream_id_tables:
|
||||
for stream_id_table, column_name in stream_id_tables:
|
||||
max_stream_id = cast(
|
||||
int,
|
||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||
|
|
|
@ -681,17 +681,17 @@ def setup_sentry(hs: "HomeServer") -> None:
|
|||
)
|
||||
|
||||
# We set some default tags that give some context to this instance
|
||||
with sentry_sdk.configure_scope() as scope:
|
||||
scope.set_tag("matrix_server_name", hs.config.server.server_name)
|
||||
global_scope = sentry_sdk.Scope.get_global_scope()
|
||||
global_scope.set_tag("matrix_server_name", hs.config.server.server_name)
|
||||
|
||||
app = (
|
||||
hs.config.worker.worker_app
|
||||
if hs.config.worker.worker_app
|
||||
else "synapse.app.homeserver"
|
||||
)
|
||||
name = hs.get_instance_name()
|
||||
scope.set_tag("worker_app", app)
|
||||
scope.set_tag("worker_name", name)
|
||||
app = (
|
||||
hs.config.worker.worker_app
|
||||
if hs.config.worker.worker_app
|
||||
else "synapse.app.homeserver"
|
||||
)
|
||||
name = hs.get_instance_name()
|
||||
global_scope.set_tag("worker_app", app)
|
||||
global_scope.set_tag("worker_name", name)
|
||||
|
||||
|
||||
def setup_sdnotify(hs: "HomeServer") -> None:
|
||||
|
|
|
@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
|
|||
# MSC3391: Removing account data.
|
||||
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
|
||||
|
||||
# MSC3575 (Sliding Sync API endpoints)
|
||||
self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
|
||||
|
||||
# MSC3773: Thread notifications
|
||||
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
|
||||
|
||||
|
@ -440,3 +443,7 @@ class ExperimentalConfig(Config):
|
|||
self.msc3823_account_suspension = experimental.get(
|
||||
"msc3823_account_suspension", False
|
||||
)
|
||||
|
||||
self.msc3916_authenticated_media_enabled = experimental.get(
|
||||
"msc3916_authenticated_media_enabled", False
|
||||
)
|
||||
|
|
|
@ -906,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
context=opentracing_context,
|
||||
)
|
||||
|
||||
await self.store.mark_redundant_device_lists_pokes(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
room_id=room_id,
|
||||
converted_upto_stream_id=stream_id,
|
||||
)
|
||||
|
||||
# Notify replication that we've updated the device list stream.
|
||||
self.notifier.notify_replication()
|
||||
|
||||
|
|
|
@ -236,6 +236,13 @@ class DeviceMessageHandler:
|
|||
local_messages = {}
|
||||
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for user_id, by_device in messages.items():
|
||||
if not UserID.is_valid(user_id):
|
||||
logger.warning(
|
||||
"Ignoring attempt to send device message to invalid user: %r",
|
||||
user_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# add an opentracing log entry for each message
|
||||
for device_id, message_content in by_device.items():
|
||||
log_kv(
|
||||
|
|
|
@ -53,6 +53,9 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
|
||||
|
||||
|
||||
class E2eKeysHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.config = hs.config
|
||||
|
@ -62,6 +65,7 @@ class E2eKeysHandler:
|
|||
self._appservice_handler = hs.get_application_service_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
|
@ -145,6 +149,11 @@ class E2eKeysHandler:
|
|||
remote_queries = {}
|
||||
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
if not UserID.is_valid(user_id):
|
||||
# Ignore invalid user IDs, which is the same behaviour as if
|
||||
# the user existed but had no keys.
|
||||
continue
|
||||
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
local_query[user_id] = device_ids
|
||||
|
@ -855,45 +864,53 @@ class E2eKeysHandler:
|
|||
async def _upload_one_time_keys_for_user(
|
||||
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Adding one_time_keys %r for device %r for user %r at %d",
|
||||
one_time_keys.keys(),
|
||||
device_id,
|
||||
user_id,
|
||||
time_now,
|
||||
)
|
||||
# We take out a lock so that we don't have to worry about a client
|
||||
# sending duplicate requests.
|
||||
lock_key = f"{user_id}_{device_id}"
|
||||
async with self._worker_lock_handler.acquire_lock(
|
||||
ONE_TIME_KEY_UPLOAD, lock_key
|
||||
):
|
||||
logger.info(
|
||||
"Adding one_time_keys %r for device %r for user %r at %d",
|
||||
one_time_keys.keys(),
|
||||
device_id,
|
||||
user_id,
|
||||
time_now,
|
||||
)
|
||||
|
||||
# make a list of (alg, id, key) tuples
|
||||
key_list = []
|
||||
for key_id, key_obj in one_time_keys.items():
|
||||
algorithm, key_id = key_id.split(":")
|
||||
key_list.append((algorithm, key_id, key_obj))
|
||||
# make a list of (alg, id, key) tuples
|
||||
key_list = []
|
||||
for key_id, key_obj in one_time_keys.items():
|
||||
algorithm, key_id = key_id.split(":")
|
||||
key_list.append((algorithm, key_id, key_obj))
|
||||
|
||||
# First we check if we have already persisted any of the keys.
|
||||
existing_key_map = await self.store.get_e2e_one_time_keys(
|
||||
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||
)
|
||||
# First we check if we have already persisted any of the keys.
|
||||
existing_key_map = await self.store.get_e2e_one_time_keys(
|
||||
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||
)
|
||||
|
||||
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
||||
for algorithm, key_id, key in key_list:
|
||||
ex_json = existing_key_map.get((algorithm, key_id), None)
|
||||
if ex_json:
|
||||
if not _one_time_keys_match(ex_json, key):
|
||||
raise SynapseError(
|
||||
400,
|
||||
(
|
||||
"One time key %s:%s already exists. "
|
||||
"Old key: %s; new key: %r"
|
||||
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
||||
for algorithm, key_id, key in key_list:
|
||||
ex_json = existing_key_map.get((algorithm, key_id), None)
|
||||
if ex_json:
|
||||
if not _one_time_keys_match(ex_json, key):
|
||||
raise SynapseError(
|
||||
400,
|
||||
(
|
||||
"One time key %s:%s already exists. "
|
||||
"Old key: %s; new key: %r"
|
||||
)
|
||||
% (algorithm, key_id, ex_json, key),
|
||||
)
|
||||
% (algorithm, key_id, ex_json, key),
|
||||
else:
|
||||
new_keys.append(
|
||||
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
|
||||
)
|
||||
else:
|
||||
new_keys.append(
|
||||
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
|
||||
)
|
||||
|
||||
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
|
||||
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
|
||||
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
|
||||
await self.store.add_e2e_one_time_keys(
|
||||
user_id, device_id, time_now, new_keys
|
||||
)
|
||||
|
||||
async def upload_signing_keys_for_user(
|
||||
self, user_id: str, keys: JsonDict
|
||||
|
|
|
@ -496,13 +496,6 @@ class EventCreationHandler:
|
|||
|
||||
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
|
||||
|
||||
self.membership_types_to_include_profile_data_in = {
|
||||
Membership.JOIN,
|
||||
Membership.KNOCK,
|
||||
}
|
||||
if self.hs.config.server.include_profile_data_on_invite:
|
||||
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
||||
|
||||
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
|
||||
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
|
||||
|
||||
|
@ -594,8 +587,6 @@ class EventCreationHandler:
|
|||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event_dict: An entire event
|
||||
|
@ -683,29 +674,6 @@ class EventCreationHandler:
|
|||
|
||||
self.validator.validate_builder(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
|
||||
if membership in self.membership_types_to_include_profile_data_in:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
displayname = await profile.get_displayname(target)
|
||||
if displayname is not None:
|
||||
content["displayname"] = displayname
|
||||
if "avatar_url" not in content:
|
||||
avatar_url = await profile.get_avatar_url(target)
|
||||
if avatar_url is not None:
|
||||
content["avatar_url"] = avatar_url
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s", target, e
|
||||
)
|
||||
|
||||
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
|
||||
if require_consent and not is_exempt:
|
||||
await self.assert_accepted_privacy_policy(requester)
|
||||
|
|
|
@ -590,7 +590,7 @@ class RegistrationHandler:
|
|||
# moving away from bare excepts is a good thing to do.
|
||||
logger.error("Failed to join new user to %r: %r", r, e)
|
||||
except Exception as e:
|
||||
logger.error("Failed to join new user to %r: %r", r, e)
|
||||
logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
|
||||
|
||||
async def _auto_join_rooms(self, user_id: str) -> None:
|
||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||
|
|
|
@ -393,9 +393,9 @@ class RelationsHandler:
|
|||
|
||||
# Attempt to find another event to use as the latest event.
|
||||
potential_events, _ = await self._main_store.get_relations_for_event(
|
||||
room_id,
|
||||
event_id,
|
||||
event,
|
||||
room_id,
|
||||
RelationTypes.THREAD,
|
||||
direction=Direction.FORWARDS,
|
||||
)
|
||||
|
|
|
@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
self.event_auth_handler = hs.get_event_auth_handler()
|
||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||
|
||||
self._membership_types_to_include_profile_data_in = {
|
||||
Membership.JOIN,
|
||||
Membership.KNOCK,
|
||||
}
|
||||
if self.hs.config.server.include_profile_data_on_invite:
|
||||
self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
||||
|
||||
self.member_linearizer: Linearizer = Linearizer(name="member")
|
||||
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
||||
|
||||
|
@ -785,9 +792,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if (
|
||||
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
||||
) or requester.shadow_banned:
|
||||
# Strip profile data, knowing that new profile data will be added to the
|
||||
# event's content in event_creation_handler.create_event() using the target's
|
||||
# global profile.
|
||||
# Strip profile data, knowing that new profile data will be added to
|
||||
# the event's content below using the target's global profile.
|
||||
content.pop("displayname", None)
|
||||
content.pop("avatar_url", None)
|
||||
|
||||
|
@ -823,6 +829,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if action in ["kick", "unban"]:
|
||||
effective_membership_state = "leave"
|
||||
|
||||
if effective_membership_state not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
# Add profile data for joins etc, if no per-room profile.
|
||||
if (
|
||||
effective_membership_state
|
||||
in self._membership_types_to_include_profile_data_in
|
||||
):
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
displayname = await profile.get_displayname(target)
|
||||
if displayname is not None:
|
||||
content["displayname"] = displayname
|
||||
if "avatar_url" not in content:
|
||||
avatar_url = await profile.get_avatar_url(target)
|
||||
if avatar_url is not None:
|
||||
content["avatar_url"] = avatar_url
|
||||
except Exception as e:
|
||||
logger.info("Failed to get profile information for %r: %s", target, e)
|
||||
|
||||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||
# invite into a normal invite before we can handle the join.
|
||||
if third_party_signed is not None:
|
||||
|
|
|
@ -28,11 +28,14 @@ from typing import (
|
|||
Dict,
|
||||
FrozenSet,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -128,6 +131,8 @@ class SyncVersion(Enum):
|
|||
|
||||
# Traditional `/sync` endpoint
|
||||
SYNC_V2 = "sync_v2"
|
||||
# Part of MSC3575 Sliding Sync
|
||||
E2EE_SYNC = "e2ee_sync"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
|
@ -279,6 +284,43 @@ class SyncResult:
|
|||
or self.device_lists
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def empty(next_batch: StreamToken) -> "SyncResult":
|
||||
"Return a new empty result"
|
||||
return SyncResult(
|
||||
next_batch=next_batch,
|
||||
presence=[],
|
||||
account_data=[],
|
||||
joined=[],
|
||||
invited=[],
|
||||
knocked=[],
|
||||
archived=[],
|
||||
to_device=[],
|
||||
device_lists=DeviceListUpdates(),
|
||||
device_one_time_keys_count={},
|
||||
device_unused_fallback_key_types=[],
|
||||
)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class E2eeSyncResult:
|
||||
"""
|
||||
Attributes:
|
||||
next_batch: Token for the next sync
|
||||
to_device: List of direct messages for the device.
|
||||
device_lists: List of user_ids whose devices have changed
|
||||
device_one_time_keys_count: Dict of algorithm to count for one time keys
|
||||
for this device
|
||||
device_unused_fallback_key_types: List of key types that have an unused fallback
|
||||
key
|
||||
"""
|
||||
|
||||
next_batch: StreamToken
|
||||
to_device: List[JsonDict]
|
||||
device_lists: DeviceListUpdates
|
||||
device_one_time_keys_count: JsonMapping
|
||||
device_unused_fallback_key_types: List[str]
|
||||
|
||||
|
||||
class SyncHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -322,6 +364,31 @@ class SyncHandler:
|
|||
|
||||
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
|
||||
|
||||
@overload
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
requester: Requester,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||
request_key: SyncRequestKey,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
timeout: int = 0,
|
||||
full_state: bool = False,
|
||||
) -> SyncResult: ...
|
||||
|
||||
@overload
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
requester: Requester,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||
request_key: SyncRequestKey,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
timeout: int = 0,
|
||||
full_state: bool = False,
|
||||
) -> E2eeSyncResult: ...
|
||||
|
||||
@overload
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
requester: Requester,
|
||||
|
@ -331,7 +398,18 @@ class SyncHandler:
|
|||
since_token: Optional[StreamToken] = None,
|
||||
timeout: int = 0,
|
||||
full_state: bool = False,
|
||||
) -> SyncResult:
|
||||
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
requester: Requester,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: SyncVersion,
|
||||
request_key: SyncRequestKey,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
timeout: int = 0,
|
||||
full_state: bool = False,
|
||||
) -> Union[SyncResult, E2eeSyncResult]:
|
||||
"""Get the sync for a client if we have new data for it now. Otherwise
|
||||
wait for new data to arrive on the server. If the timeout expires, then
|
||||
return an empty sync result.
|
||||
|
@ -344,8 +422,10 @@ class SyncHandler:
|
|||
since_token: The point in the stream to sync from.
|
||||
timeout: How long to wait for new data to arrive before giving up.
|
||||
full_state: Whether to return the full state for each room.
|
||||
|
||||
Returns:
|
||||
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
||||
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
|
||||
"""
|
||||
# If the user is not part of the mau group, then check that limits have
|
||||
# not been exceeded (if not part of the group by this point, almost certain
|
||||
|
@ -366,6 +446,29 @@ class SyncHandler:
|
|||
logger.debug("Returning sync response for %s", user_id)
|
||||
return res
|
||||
|
||||
@overload
|
||||
async def _wait_for_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||
since_token: Optional[StreamToken],
|
||||
timeout: int,
|
||||
full_state: bool,
|
||||
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||
) -> SyncResult: ...
|
||||
|
||||
@overload
|
||||
async def _wait_for_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||
since_token: Optional[StreamToken],
|
||||
timeout: int,
|
||||
full_state: bool,
|
||||
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||
) -> E2eeSyncResult: ...
|
||||
|
||||
@overload
|
||||
async def _wait_for_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
|
@ -374,7 +477,17 @@ class SyncHandler:
|
|||
timeout: int,
|
||||
full_state: bool,
|
||||
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||
) -> SyncResult:
|
||||
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||
|
||||
async def _wait_for_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: SyncVersion,
|
||||
since_token: Optional[StreamToken],
|
||||
timeout: int,
|
||||
full_state: bool,
|
||||
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||
) -> Union[SyncResult, E2eeSyncResult]:
|
||||
"""The start of the machinery that produces a /sync response.
|
||||
|
||||
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
|
||||
|
@ -401,6 +514,24 @@ class SyncHandler:
|
|||
if context:
|
||||
context.tag = sync_label
|
||||
|
||||
if since_token is not None:
|
||||
# We need to make sure this worker has caught up with the token. If
|
||||
# this returns false it means we timed out waiting, and we should
|
||||
# just return an empty response.
|
||||
start = self.clock.time_msec()
|
||||
if not await self.notifier.wait_for_stream_token(since_token):
|
||||
logger.warning(
|
||||
"Timed out waiting for worker to catch up. Returning empty response"
|
||||
)
|
||||
return SyncResult.empty(since_token)
|
||||
|
||||
# If we've spent significant time waiting to catch up, take it off
|
||||
# the timeout.
|
||||
now = self.clock.time_msec()
|
||||
if now - start > 1_000:
|
||||
timeout -= now - start
|
||||
timeout = max(timeout, 0)
|
||||
|
||||
# if we have a since token, delete any to-device messages before that token
|
||||
# (since we now know that the device has received them)
|
||||
if since_token is not None:
|
||||
|
@ -417,14 +548,16 @@ class SyncHandler:
|
|||
if timeout == 0 or since_token is None or full_state:
|
||||
# we are going to return immediately, so don't bother calling
|
||||
# notifier.wait_for_events.
|
||||
result: SyncResult = await self.current_sync_for_user(
|
||||
sync_config, sync_version, since_token, full_state=full_state
|
||||
result: Union[SyncResult, E2eeSyncResult] = (
|
||||
await self.current_sync_for_user(
|
||||
sync_config, sync_version, since_token, full_state=full_state
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Otherwise, we wait for something to happen and report it to the user.
|
||||
async def current_sync_callback(
|
||||
before_token: StreamToken, after_token: StreamToken
|
||||
) -> SyncResult:
|
||||
) -> Union[SyncResult, E2eeSyncResult]:
|
||||
return await self.current_sync_for_user(
|
||||
sync_config, sync_version, since_token
|
||||
)
|
||||
|
@ -456,14 +589,43 @@ class SyncHandler:
|
|||
|
||||
return result
|
||||
|
||||
@overload
|
||||
async def current_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||
since_token: Optional[StreamToken] = None,
|
||||
full_state: bool = False,
|
||||
) -> SyncResult: ...
|
||||
|
||||
@overload
|
||||
async def current_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||
since_token: Optional[StreamToken] = None,
|
||||
full_state: bool = False,
|
||||
) -> E2eeSyncResult: ...
|
||||
|
||||
@overload
|
||||
async def current_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: SyncVersion,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
full_state: bool = False,
|
||||
) -> SyncResult:
|
||||
"""Generates the response body of a sync result, represented as a SyncResult.
|
||||
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||
|
||||
async def current_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
sync_version: SyncVersion,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
full_state: bool = False,
|
||||
) -> Union[SyncResult, E2eeSyncResult]:
|
||||
"""
|
||||
Generates the response body of a sync result, represented as a
|
||||
`SyncResult`/`E2eeSyncResult`.
|
||||
|
||||
This is a wrapper around `generate_sync_result` which starts an open tracing
|
||||
span to track the sync. See `generate_sync_result` for the next part of your
|
||||
|
@ -474,15 +636,25 @@ class SyncHandler:
|
|||
sync_version: Determines what kind of sync response to generate.
|
||||
since_token: The point in the stream to sync from.p.
|
||||
full_state: Whether to return the full state for each room.
|
||||
|
||||
Returns:
|
||||
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
||||
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
|
||||
"""
|
||||
with start_active_span("sync.current_sync_for_user"):
|
||||
log_kv({"since_token": since_token})
|
||||
|
||||
# Go through the `/sync` v2 path
|
||||
if sync_version == SyncVersion.SYNC_V2:
|
||||
sync_result: SyncResult = await self.generate_sync_result(
|
||||
sync_config, since_token, full_state
|
||||
sync_result: Union[SyncResult, E2eeSyncResult] = (
|
||||
await self.generate_sync_result(
|
||||
sync_config, since_token, full_state
|
||||
)
|
||||
)
|
||||
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
|
||||
elif sync_version == SyncVersion.E2EE_SYNC:
|
||||
sync_result = await self.generate_e2ee_sync_result(
|
||||
sync_config, since_token
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
|
@ -1691,6 +1863,96 @@ class SyncHandler:
|
|||
next_batch=sync_result_builder.now_token,
|
||||
)
|
||||
|
||||
async def generate_e2ee_sync_result(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
since_token: Optional[StreamToken] = None,
|
||||
) -> E2eeSyncResult:
|
||||
"""
|
||||
Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
|
||||
|
||||
This is represented by a `E2eeSyncResult` struct, which is built from small
|
||||
pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
|
||||
mutable ("inout") parameter to various helper functions. These retrieve and
|
||||
process the data which forms the sync body, often writing to the
|
||||
`sync_result_builder` to store their output.
|
||||
|
||||
At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
|
||||
instance to signify that the sync calculation is complete.
|
||||
"""
|
||||
user_id = sync_config.user.to_string()
|
||||
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||
if app_service:
|
||||
# We no longer support AS users using /sync directly.
|
||||
# See https://github.com/matrix-org/matrix-doc/issues/1144
|
||||
raise NotImplementedError()
|
||||
|
||||
sync_result_builder = await self.get_sync_result_builder(
|
||||
sync_config,
|
||||
since_token,
|
||||
full_state=False,
|
||||
)
|
||||
|
||||
# 1. Calculate `to_device` events
|
||||
await self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||
|
||||
# 2. Calculate `device_lists`
|
||||
# Device list updates are sent if a since token is provided.
|
||||
device_lists = DeviceListUpdates()
|
||||
include_device_list_updates = bool(since_token and since_token.device_list_key)
|
||||
if include_device_list_updates:
|
||||
# Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
|
||||
# is used in calculate_user_changes below.
|
||||
#
|
||||
# TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
|
||||
# figure out the membership changes/derived info needed for
|
||||
# `_generate_sync_entry_for_device_list()`. In the future, we should try to
|
||||
# refactor this away.
|
||||
(
|
||||
newly_joined_rooms,
|
||||
newly_left_rooms,
|
||||
) = await self._generate_sync_entry_for_rooms(sync_result_builder)
|
||||
|
||||
# This uses the sync_result_builder.joined which is set in
|
||||
# `_generate_sync_entry_for_rooms`, if that didn't find any joined
|
||||
# rooms for some reason it is a no-op.
|
||||
(
|
||||
newly_joined_or_invited_or_knocked_users,
|
||||
newly_left_users,
|
||||
) = sync_result_builder.calculate_user_changes()
|
||||
|
||||
device_lists = await self._generate_sync_entry_for_device_list(
|
||||
sync_result_builder,
|
||||
newly_joined_rooms=newly_joined_rooms,
|
||||
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
|
||||
newly_left_rooms=newly_left_rooms,
|
||||
newly_left_users=newly_left_users,
|
||||
)
|
||||
|
||||
# 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
|
||||
device_id = sync_config.device_id
|
||||
one_time_keys_count: JsonMapping = {}
|
||||
unused_fallback_key_types: List[str] = []
|
||||
if device_id:
|
||||
# TODO: We should have a way to let clients differentiate between the states of:
|
||||
# * no change in OTK count since the provided since token
|
||||
# * the server has zero OTKs left for this device
|
||||
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
|
||||
one_time_keys_count = await self.store.count_e2e_one_time_keys(
|
||||
user_id, device_id
|
||||
)
|
||||
unused_fallback_key_types = list(
|
||||
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
|
||||
)
|
||||
|
||||
return E2eeSyncResult(
|
||||
to_device=sync_result_builder.to_device,
|
||||
device_lists=device_lists,
|
||||
device_one_time_keys_count=one_time_keys_count,
|
||||
device_unused_fallback_key_types=unused_fallback_key_types,
|
||||
next_batch=sync_result_builder.now_token,
|
||||
)
|
||||
|
||||
async def get_sync_result_builder(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
|
@ -1889,7 +2151,7 @@ class SyncHandler:
|
|||
users_that_have_changed = (
|
||||
await self._device_handler.get_device_changes_in_shared_rooms(
|
||||
user_id,
|
||||
sync_result_builder.joined_room_ids,
|
||||
joined_room_ids,
|
||||
from_token=since_token,
|
||||
now_token=sync_result_builder.now_token,
|
||||
)
|
||||
|
|
|
@ -477,9 +477,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||
|
||||
rows = []
|
||||
for room_id in changed_rooms:
|
||||
serial = self._room_serials[room_id]
|
||||
if last_id < serial <= current_id:
|
||||
typing = self._room_typing[room_id]
|
||||
serial = self._room_serials.get(room_id)
|
||||
if serial and last_id < serial <= current_id:
|
||||
typing = self._room_typing.get(room_id, set())
|
||||
rows.append((serial, [room_id, list(typing)]))
|
||||
rows.sort()
|
||||
|
||||
|
|
|
@ -650,7 +650,7 @@ class MediaRepository:
|
|||
|
||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
try:
|
||||
length, headers = await self.client.download_media(
|
||||
server_name,
|
||||
|
@ -693,8 +693,6 @@ class MediaRepository:
|
|||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
await finish()
|
||||
|
||||
if b"Content-Type" in headers:
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
else:
|
||||
|
@ -1045,17 +1043,17 @@ class MediaRepository:
|
|||
),
|
||||
)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (
|
||||
f,
|
||||
fname,
|
||||
finish,
|
||||
):
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
try:
|
||||
await self.media_storage.write_to_file(t_byte_source, f)
|
||||
await finish()
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
# We flush and close the file to ensure that the bytes have
|
||||
# been written before getting the size.
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
t_len = os.path.getsize(fname)
|
||||
|
||||
# Write to database
|
||||
|
|
|
@ -27,10 +27,9 @@ from typing import (
|
|||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
AsyncIterator,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
|
@ -97,11 +96,9 @@ class MediaStorage:
|
|||
the file path written to in the primary media store
|
||||
"""
|
||||
|
||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||
async with self.store_into_file(file_info) as (f, fname):
|
||||
# Write to the main media repository
|
||||
await self.write_to_file(source, f)
|
||||
# Write to the other storage providers
|
||||
await finish_cb()
|
||||
|
||||
return fname
|
||||
|
||||
|
@ -111,32 +108,27 @@ class MediaStorage:
|
|||
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
||||
|
||||
@trace_with_opname("MediaStorage.store_into_file")
|
||||
@contextlib.contextmanager
|
||||
def store_into_file(
|
||||
@contextlib.asynccontextmanager
|
||||
async def store_into_file(
|
||||
self, file_info: FileInfo
|
||||
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
|
||||
"""Context manager used to get a file like object to write into, as
|
||||
) -> AsyncIterator[Tuple[BinaryIO, str]]:
|
||||
"""Async Context manager used to get a file like object to write into, as
|
||||
described by file_info.
|
||||
|
||||
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
||||
like object that can be written to, fname is the absolute path of file
|
||||
on disk, and finish_cb is a function that returns an awaitable.
|
||||
Actually yields a 2-tuple (file, fname,), where file is a file
|
||||
like object that can be written to and fname is the absolute path of file
|
||||
on disk.
|
||||
|
||||
fname can be used to read the contents from after upload, e.g. to
|
||||
generate thumbnails.
|
||||
|
||||
finish_cb must be called and waited on after the file has been successfully been
|
||||
written to. Should not be called if there was an error. Checks for spam and
|
||||
stores the file into the configured storage providers.
|
||||
|
||||
Args:
|
||||
file_info: Info about the file to store
|
||||
|
||||
Example:
|
||||
|
||||
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
||||
async with media_storage.store_into_file(info) as (f, fname,):
|
||||
# .. write into f ...
|
||||
await finish_cb()
|
||||
"""
|
||||
|
||||
path = self._file_info_to_path(file_info)
|
||||
|
@ -145,63 +137,38 @@ class MediaStorage:
|
|||
dirname = os.path.dirname(fname)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
|
||||
finished_called = [False]
|
||||
|
||||
main_media_repo_write_trace_scope = start_active_span(
|
||||
"writing to main media repo"
|
||||
)
|
||||
main_media_repo_write_trace_scope.__enter__()
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
with start_active_span("writing to main media repo"):
|
||||
with open(fname, "wb") as f:
|
||||
yield f, fname
|
||||
|
||||
async def finish() -> None:
|
||||
# When someone calls finish, we assume they are done writing to the main media repo
|
||||
main_media_repo_write_trace_scope.__exit__(None, None, None)
|
||||
with start_active_span("writing to other storage providers"):
|
||||
spam_check = (
|
||||
await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
||||
ReadableFileWrapper(self.clock, fname), file_info
|
||||
)
|
||||
)
|
||||
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
||||
logger.info("Blocking media due to spam checker")
|
||||
# Note that we'll delete the stored media, due to the
|
||||
# try/except below. The media also won't be stored in
|
||||
# the DB.
|
||||
# We currently ignore any additional field returned by
|
||||
# the spam-check API.
|
||||
raise SpamMediaException(errcode=spam_check[0])
|
||||
|
||||
with start_active_span("writing to other storage providers"):
|
||||
# Ensure that all writes have been flushed and close the
|
||||
# file.
|
||||
f.flush()
|
||||
f.close()
|
||||
for provider in self.storage_providers:
|
||||
with start_active_span(str(provider)):
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
||||
ReadableFileWrapper(self.clock, fname), file_info
|
||||
)
|
||||
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
||||
logger.info("Blocking media due to spam checker")
|
||||
# Note that we'll delete the stored media, due to the
|
||||
# try/except below. The media also won't be stored in
|
||||
# the DB.
|
||||
# We currently ignore any additional field returned by
|
||||
# the spam-check API.
|
||||
raise SpamMediaException(errcode=spam_check[0])
|
||||
|
||||
for provider in self.storage_providers:
|
||||
with start_active_span(str(provider)):
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
yield f, fname, finish
|
||||
except Exception as e:
|
||||
try:
|
||||
main_media_repo_write_trace_scope.__exit__(
|
||||
type(e), None, e.__traceback__
|
||||
)
|
||||
os.remove(fname)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise e from None
|
||||
|
||||
if not finished_called:
|
||||
exc = Exception("Finished callback not called")
|
||||
main_media_repo_write_trace_scope.__exit__(
|
||||
type(exc), None, exc.__traceback__
|
||||
)
|
||||
raise exc
|
||||
|
||||
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
||||
"""Attempts to fetch media described by file_info from the local cache
|
||||
and configured storage providers.
|
||||
|
|
|
@ -22,11 +22,27 @@
|
|||
import logging
|
||||
from io import BytesIO
|
||||
from types import TracebackType
|
||||
from typing import Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.media._base import (
|
||||
FileInfo,
|
||||
ThumbnailInfo,
|
||||
respond_404,
|
||||
respond_with_file,
|
||||
respond_with_responder,
|
||||
)
|
||||
from synapse.media.media_storage import MediaStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.media.media_repository import MediaRepository
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -231,3 +247,471 @@ class Thumbnailer:
|
|||
def __del__(self) -> None:
|
||||
# Make sure we actually do close the image, rather than leak data.
|
||||
self.close()
|
||||
|
||||
|
||||
class ThumbnailProvider:
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
media_repo: "MediaRepository",
|
||||
media_storage: MediaStorage,
|
||||
):
|
||||
self.hs = hs
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.store = hs.get_datastores().main
|
||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||
|
||||
async def respond_local_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_local_media_info(
|
||||
request, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
media_id,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
server_name=None,
|
||||
)
|
||||
|
||||
async def select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_local_media_info(
|
||||
request, media_id, max_timeout_ms
|
||||
)
|
||||
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
for info in thumbnail_infos:
|
||||
t_w = info.width == desired_width
|
||||
t_h = info.height == desired_height
|
||||
t_method = info.method == desired_method
|
||||
t_type = info.type == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request, responder, info.type, info.length
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = await self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
)
|
||||
|
||||
if file_path:
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||
|
||||
async def select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_remote_media_info(
|
||||
server_name, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
file_id = media_info.filesystem_id
|
||||
|
||||
for info in thumbnail_infos:
|
||||
t_w = info.width == desired_width
|
||||
t_h = info.height == desired_height
|
||||
t_method = info.method == desired_method
|
||||
t_type = info.type == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request, responder, info.type, info.length
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = await self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name,
|
||||
file_id,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
)
|
||||
|
||||
if file_path:
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||
|
||||
async def respond_remote_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = await self.media_repo.get_remote_media_info(
|
||||
server_name, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
media_info.filesystem_id,
|
||||
url_cache=False,
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
async def _select_and_respond_with_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[ThumbnailInfo],
|
||||
media_id: str,
|
||||
file_id: str,
|
||||
url_cache: bool,
|
||||
server_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: True if this is from a URL cache.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
"""
|
||||
logger.debug(
|
||||
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
thumbnail_infos,
|
||||
)
|
||||
|
||||
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
|
||||
# different code path to handle it.
|
||||
assert not self.dynamic_thumbnails
|
||||
|
||||
if thumbnail_infos:
|
||||
file_info = self._select_thumbnail(
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
thumbnail_infos,
|
||||
file_id,
|
||||
url_cache,
|
||||
server_name,
|
||||
)
|
||||
if not file_info:
|
||||
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# The thumbnail property must exist.
|
||||
assert file_info.thumbnail is not None
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request,
|
||||
responder,
|
||||
file_info.thumbnail.type,
|
||||
file_info.thumbnail.length,
|
||||
)
|
||||
return
|
||||
|
||||
# If we can't find the thumbnail we regenerate it. This can happen
|
||||
# if e.g. we've deleted the thumbnails but still have the original
|
||||
# image somewhere.
|
||||
#
|
||||
# Since we have an entry for the thumbnail in the DB we a) know we
|
||||
# have have successfully generated the thumbnail in the past (so we
|
||||
# don't need to worry about repeatedly failing to generate
|
||||
# thumbnails), and b) have already calculated that appropriate
|
||||
# width/height/method so we can just call the "generate exact"
|
||||
# methods.
|
||||
|
||||
# First let's check that we do actually have the original image
|
||||
# still. This will throw a 404 if we don't.
|
||||
# TODO: We should refetch the thumbnails for remote media.
|
||||
await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
if server_name:
|
||||
await self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name,
|
||||
file_id=file_id,
|
||||
media_id=media_id,
|
||||
t_width=file_info.thumbnail.width,
|
||||
t_height=file_info.thumbnail.height,
|
||||
t_method=file_info.thumbnail.method,
|
||||
t_type=file_info.thumbnail.type,
|
||||
)
|
||||
else:
|
||||
await self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id=media_id,
|
||||
t_width=file_info.thumbnail.width,
|
||||
t_height=file_info.thumbnail.height,
|
||||
t_method=file_info.thumbnail.method,
|
||||
t_type=file_info.thumbnail.type,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(
|
||||
request,
|
||||
responder,
|
||||
file_info.thumbnail.type,
|
||||
file_info.thumbnail.length,
|
||||
)
|
||||
else:
|
||||
# This might be because:
|
||||
# 1. We can't create thumbnails for the given media (corrupted or
|
||||
# unsupported file type), or
|
||||
# 2. The thumbnailing process never ran or errored out initially
|
||||
# when the media was first uploaded (these bugs should be
|
||||
# reported and fixed).
|
||||
# Note that we don't attempt to generate a thumbnail now because
|
||||
# `dynamic_thumbnails` is disabled.
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
|
||||
assert request.path is not None
|
||||
respond_with_json(
|
||||
request,
|
||||
400,
|
||||
cs_error(
|
||||
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
|
||||
% (
|
||||
request.path.decode(),
|
||||
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
|
||||
),
|
||||
code=Codes.UNKNOWN,
|
||||
),
|
||||
send_cors=True,
|
||||
)
|
||||
|
||||
def _select_thumbnail(
|
||||
self,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[ThumbnailInfo],
|
||||
file_id: str,
|
||||
url_cache: bool,
|
||||
server_name: Optional[str],
|
||||
) -> Optional[FileInfo]:
|
||||
"""
|
||||
Choose an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: True if this is from a URL cache.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
|
||||
Returns:
|
||||
The thumbnail which best matches the desired parameters.
|
||||
"""
|
||||
desired_method = desired_method.lower()
|
||||
|
||||
# The chosen thumbnail.
|
||||
thumbnail_info = None
|
||||
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
if desired_method == "crop":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
crop_info_list: List[
|
||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||
] = []
|
||||
# Other thumbnails.
|
||||
crop_info_list2: List[
|
||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||
] = []
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info.method != "crop":
|
||||
continue
|
||||
|
||||
t_w = info.width
|
||||
t_h = info.height
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info.type
|
||||
length_quality = info.length
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
crop_info_list.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
else:
|
||||
crop_info_list2.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if crop_info_list:
|
||||
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
|
||||
elif crop_info_list2:
|
||||
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
|
||||
elif desired_method == "scale":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||
# Other thumbnails.
|
||||
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info.method != "scale":
|
||||
continue
|
||||
|
||||
t_w = info.width
|
||||
t_h = info.height
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info.type
|
||||
length_quality = info.length
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
info_list.append((size_quality, type_quality, length_quality, info))
|
||||
else:
|
||||
info_list2.append(
|
||||
(size_quality, type_quality, length_quality, info)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if info_list:
|
||||
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
|
||||
elif info_list2:
|
||||
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
|
||||
|
||||
if thumbnail_info:
|
||||
return FileInfo(
|
||||
file_id=file_id,
|
||||
url_cache=url_cache,
|
||||
server_name=server_name,
|
||||
thumbnail=thumbnail_info,
|
||||
)
|
||||
|
||||
# No matching thumbnail was found.
|
||||
return None
|
||||
|
|
|
@ -592,7 +592,7 @@ class UrlPreviewer:
|
|||
|
||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
if url.startswith("data:"):
|
||||
if not allow_data_urls:
|
||||
raise SynapseError(
|
||||
|
@ -603,8 +603,6 @@ class UrlPreviewer:
|
|||
else:
|
||||
download_result = await self._download_url(url, f)
|
||||
|
||||
await finish()
|
||||
|
||||
try:
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -763,6 +763,29 @@ class Notifier:
|
|||
|
||||
return result
|
||||
|
||||
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
|
||||
"""Wait for this worker to catch up with the given stream token."""
|
||||
|
||||
start = self.clock.time_msec()
|
||||
while True:
|
||||
current_token = self.event_sources.get_current_token()
|
||||
if stream_token.is_before_or_eq(current_token):
|
||||
return True
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
if now - start > 10_000:
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
"Waiting for current token to reach %s; currently at %s",
|
||||
stream_token,
|
||||
current_token,
|
||||
)
|
||||
|
||||
# TODO: be better
|
||||
await self.clock.sleep(0.5)
|
||||
|
||||
async def _get_room_ids(
|
||||
self, user: UserID, explicit_room_id: Optional[str]
|
||||
) -> Tuple[StrCollection, bool]:
|
||||
|
|
205
synapse/rest/client/media.py
Normal file
205
synapse/rest/client/media.py
Normal file
|
@ -0,0 +1,205 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from synapse.http.server import (
|
||||
HttpServer,
|
||||
respond_with_json,
|
||||
respond_with_json_bytes,
|
||||
set_corp_headers,
|
||||
set_cors_headers,
|
||||
)
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.media._base import (
|
||||
DEFAULT_MAX_TIMEOUT_MS,
|
||||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||
respond_404,
|
||||
)
|
||||
from synapse.media.media_repository import MediaRepository
|
||||
from synapse.media.media_storage import MediaStorage
|
||||
from synapse.media.thumbnailer import ThumbnailProvider
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstablePreviewURLServlet(RestServlet):
|
||||
"""
|
||||
Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
|
||||
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
|
||||
specific additions).
|
||||
|
||||
This does have trade-offs compared to other designs:
|
||||
|
||||
* Pros:
|
||||
* Simple and flexible; can be used by any clients at any point
|
||||
* Cons:
|
||||
* If each homeserver provides one of these independently, all the homeservers in a
|
||||
room may needlessly DoS the target URI
|
||||
* The URL metadata must be stored somewhere, rather than just using Matrix
|
||||
itself to store the media.
|
||||
* Matrix cannot be used to distribute the metadata between homeservers.
|
||||
"""
|
||||
|
||||
PATTERNS = [
|
||||
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
media_repo: "MediaRepository",
|
||||
media_storage: MediaStorage,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
assert self.media_repo.url_previewer is not None
|
||||
self.url_previewer = self.media_repo.url_previewer
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
url = parse_string(request, "url", required=True)
|
||||
ts = parse_integer(request, "ts")
|
||||
if ts is None:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
og = await self.url_previewer.preview(url, requester.user, ts)
|
||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||
|
||||
|
||||
class UnstableMediaConfigResource(RestServlet):
|
||||
PATTERNS = [
|
||||
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
config = hs.config
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> None:
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
|
||||
class UnstableThumbnailResource(RestServlet):
|
||||
PATTERNS = [
|
||||
re.compile(
|
||||
"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||
)
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
media_repo: "MediaRepository",
|
||||
media_storage: MediaStorage,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._server_name = hs.hostname
|
||||
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||
self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> None:
|
||||
# Validate the server name, raising if invalid
|
||||
parse_and_validate_server_name(server_name)
|
||||
await self.auth.get_user_by_req(request)
|
||||
|
||||
set_cors_headers(request)
|
||||
set_corp_headers(request)
|
||||
width = parse_integer(request, "width", required=True)
|
||||
height = parse_integer(request, "height", required=True)
|
||||
method = parse_string(request, "method", "scale")
|
||||
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
||||
m_type = "image/png"
|
||||
max_timeout_ms = parse_integer(
|
||||
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||
)
|
||||
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||
|
||||
if self._is_mine_server_name(server_name):
|
||||
if self.dynamic_thumbnails:
|
||||
await self.thumbnailer.select_or_generate_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||
)
|
||||
else:
|
||||
await self.thumbnailer.respond_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(None, media_id)
|
||||
else:
|
||||
# Don't let users download media from configured domains, even if it
|
||||
# is already downloaded. This is Trust & Safety tooling to make some
|
||||
# media inaccessible to local users.
|
||||
# See `prevent_media_downloads_from` config docs for more info.
|
||||
if server_name in self.prevent_media_downloads_from:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
remote_resp_function = (
|
||||
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
else self.thumbnailer.respond_remote_thumbnail
|
||||
)
|
||||
await remote_resp_function(
|
||||
request,
|
||||
server_name,
|
||||
media_id,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
max_timeout_ms,
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3916_authenticated_media_enabled:
|
||||
media_repo = hs.get_media_repository()
|
||||
if hs.config.media.url_preview_enabled:
|
||||
UnstablePreviewURLServlet(
|
||||
hs, media_repo, media_repo.media_storage
|
||||
).register(http_server)
|
||||
UnstableMediaConfigResource(hs).register(http_server)
|
||||
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
||||
http_server
|
||||
)
|
|
@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
|
|||
return result
|
||||
|
||||
|
||||
class SlidingSyncE2eeRestServlet(RestServlet):
|
||||
"""
|
||||
API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
|
||||
of Sliding Sync but doesn't have any sliding window component. It's just a way to
|
||||
get E2EE events without having to sit through a big initial sync (`/sync` v2). And
|
||||
we can avoid encryption events being backed up by the main sync response.
|
||||
|
||||
Having To-Device messages split out to this sync endpoint also helps when clients
|
||||
need to have 2 or more sync streams open at a time, e.g a push notification process
|
||||
and a main process. This can cause the two processes to race to fetch the To-Device
|
||||
events, resulting in the need for complex synchronisation rules to ensure the token
|
||||
is correctly and atomically exchanged between processes.
|
||||
|
||||
GET parameters::
|
||||
timeout(int): How long to wait for new events in milliseconds.
|
||||
since(batch_token): Batch token when asking for incremental deltas.
|
||||
|
||||
Response JSON::
|
||||
{
|
||||
"next_batch": // batch token for the next /sync
|
||||
"to_device": {
|
||||
// list of to-device events
|
||||
"events": [
|
||||
{
|
||||
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
|
||||
"type": "m.room.encrypted",
|
||||
"sender": "@alice:example.com",
|
||||
}
|
||||
// ...
|
||||
]
|
||||
},
|
||||
"device_lists": {
|
||||
"changed": ["@alice:example.com"],
|
||||
"left": ["@bob:example.com"]
|
||||
},
|
||||
"device_one_time_keys_count": {
|
||||
"signed_curve25519": 50
|
||||
},
|
||||
"device_unused_fallback_key_types": [
|
||||
"signed_curve25519"
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.sync_handler = hs.get_sync_handler()
|
||||
|
||||
# Filtering only matters for the `device_lists` because it requires a bunch of
|
||||
# derived information from rooms (see how `_generate_sync_entry_for_rooms()`
|
||||
# prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
|
||||
self.only_member_events_filter_collection = FilterCollection(
|
||||
self.hs,
|
||||
{
|
||||
"room": {
|
||||
# We only care about membership events for the `device_lists`.
|
||||
# Membership will tell us whether a user has joined/left a room and
|
||||
# if there are new devices to encrypt for.
|
||||
"timeline": {
|
||||
"types": ["m.room.member"],
|
||||
},
|
||||
"state": {
|
||||
"types": ["m.room.member"],
|
||||
},
|
||||
# We don't want any extra account_data generated because it's not
|
||||
# returned by this endpoint. This helps us avoid work in
|
||||
# `_generate_sync_entry_for_rooms()`
|
||||
"account_data": {
|
||||
"not_types": ["*"],
|
||||
},
|
||||
# We don't want any extra ephemeral data generated because it's not
|
||||
# returned by this endpoint. This helps us avoid work in
|
||||
# `_generate_sync_entry_for_rooms()`
|
||||
"ephemeral": {
|
||||
"not_types": ["*"],
|
||||
},
|
||||
},
|
||||
# We don't want any extra account_data generated because it's not
|
||||
# returned by this endpoint. (This is just here for good measure)
|
||||
"account_data": {
|
||||
"not_types": ["*"],
|
||||
},
|
||||
# We don't want any extra presence data generated because it's not
|
||||
# returned by this endpoint. (This is just here for good measure)
|
||||
"presence": {
|
||||
"not_types": ["*"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = requester.user
|
||||
device_id = requester.device_id
|
||||
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
since = parse_string(request, "since")
|
||||
|
||||
sync_config = SyncConfig(
|
||||
user=user,
|
||||
filter_collection=self.only_member_events_filter_collection,
|
||||
is_guest=requester.is_guest,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
since_token = None
|
||||
if since is not None:
|
||||
since_token = await StreamToken.from_string(self.store, since)
|
||||
|
||||
# Request cache key
|
||||
request_key = (
|
||||
SyncVersion.E2EE_SYNC,
|
||||
user,
|
||||
timeout,
|
||||
since,
|
||||
)
|
||||
|
||||
# Gather data for the response
|
||||
sync_result = await self.sync_handler.wait_for_sync_for_user(
|
||||
requester,
|
||||
sync_config,
|
||||
SyncVersion.E2EE_SYNC,
|
||||
request_key,
|
||||
since_token=since_token,
|
||||
timeout=timeout,
|
||||
full_state=False,
|
||||
)
|
||||
|
||||
# The client may have disconnected by now; don't bother to serialize the
|
||||
# response if so.
|
||||
if request._disconnected:
|
||||
logger.info("Client has disconnected; not serializing response.")
|
||||
return 200, {}
|
||||
|
||||
response: JsonDict = defaultdict(dict)
|
||||
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
|
||||
|
||||
if sync_result.to_device:
|
||||
response["to_device"] = {"events": sync_result.to_device}
|
||||
|
||||
if sync_result.device_lists.changed:
|
||||
response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
|
||||
if sync_result.device_lists.left:
|
||||
response["device_lists"]["left"] = list(sync_result.device_lists.left)
|
||||
|
||||
# We always include this because https://github.com/vector-im/element-android/issues/3725
|
||||
# The spec isn't terribly clear on when this can be omitted and how a client would tell
|
||||
# the difference between "no keys present" and "nothing changed" in terms of whole field
|
||||
# absent / individual key type entry absent
|
||||
# Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
|
||||
response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
|
||||
|
||||
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
|
||||
# states that this field should always be included, as long as the server supports the feature.
|
||||
response["device_unused_fallback_key_types"] = (
|
||||
sync_result.device_unused_fallback_key_types
|
||||
)
|
||||
|
||||
return 200, response
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.msc3575_enabled:
|
||||
SlidingSyncE2eeRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -22,23 +22,18 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
|
||||
from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
|
||||
from synapse.http.server import set_corp_headers, set_cors_headers
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.media._base import (
|
||||
DEFAULT_MAX_TIMEOUT_MS,
|
||||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||
FileInfo,
|
||||
ThumbnailInfo,
|
||||
respond_404,
|
||||
respond_with_file,
|
||||
respond_with_responder,
|
||||
)
|
||||
from synapse.media.media_storage import MediaStorage
|
||||
from synapse.media.thumbnailer import ThumbnailProvider
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
|
|||
self.store = hs.get_datastores().main
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._server_name = hs.hostname
|
||||
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||
self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
|
@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
|
|||
|
||||
if self._is_mine_server_name(server_name):
|
||||
if self.dynamic_thumbnails:
|
||||
await self._select_or_generate_local_thumbnail(
|
||||
await self.thumbnail_provider.select_or_generate_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||
)
|
||||
else:
|
||||
await self._respond_local_thumbnail(
|
||||
await self.thumbnail_provider.respond_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(None, media_id)
|
||||
|
@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
|
|||
return
|
||||
|
||||
remote_resp_function = (
|
||||
self._select_or_generate_remote_thumbnail
|
||||
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
else self._respond_remote_thumbnail
|
||||
else self.thumbnail_provider.respond_remote_thumbnail
|
||||
)
|
||||
await remote_resp_function(
|
||||
request,
|
||||
|
@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
|
|||
max_timeout_ms,
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
async def _respond_local_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_local_media_info(
|
||||
request, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
media_id,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
server_name=None,
|
||||
)
|
||||
|
||||
async def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_local_media_info(
|
||||
request, media_id, max_timeout_ms
|
||||
)
|
||||
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
for info in thumbnail_infos:
|
||||
t_w = info.width == desired_width
|
||||
t_h = info.height == desired_height
|
||||
t_method = info.method == desired_method
|
||||
t_type = info.type == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request, responder, info.type, info.length
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = await self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
url_cache=bool(media_info.url_cache),
|
||||
)
|
||||
|
||||
if file_path:
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||
|
||||
async def _select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_remote_media_info(
|
||||
server_name, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
file_id = media_info.filesystem_id
|
||||
|
||||
for info in thumbnail_infos:
|
||||
t_w = info.width == desired_width
|
||||
t_h = info.height == desired_height
|
||||
t_method = info.method == desired_method
|
||||
t_type = info.type == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
thumbnail=info,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request, responder, info.type, info.length
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = await self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name,
|
||||
file_id,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
)
|
||||
|
||||
if file_path:
|
||||
await respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warning("Failed to generate thumbnail")
|
||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||
|
||||
async def _respond_remote_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
max_timeout_ms: int,
|
||||
) -> None:
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = await self.media_repo.get_remote_media_info(
|
||||
server_name, media_id, max_timeout_ms
|
||||
)
|
||||
if not media_info:
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
media_info.filesystem_id,
|
||||
url_cache=False,
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
async def _select_and_respond_with_thumbnail(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[ThumbnailInfo],
|
||||
media_id: str,
|
||||
file_id: str,
|
||||
url_cache: bool,
|
||||
server_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: True if this is from a URL cache.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
"""
|
||||
logger.debug(
|
||||
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
thumbnail_infos,
|
||||
)
|
||||
|
||||
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
|
||||
# different code path to handle it.
|
||||
assert not self.dynamic_thumbnails
|
||||
|
||||
if thumbnail_infos:
|
||||
file_info = self._select_thumbnail(
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
thumbnail_infos,
|
||||
file_id,
|
||||
url_cache,
|
||||
server_name,
|
||||
)
|
||||
if not file_info:
|
||||
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# The thumbnail property must exist.
|
||||
assert file_info.thumbnail is not None
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
await respond_with_responder(
|
||||
request,
|
||||
responder,
|
||||
file_info.thumbnail.type,
|
||||
file_info.thumbnail.length,
|
||||
)
|
||||
return
|
||||
|
||||
# If we can't find the thumbnail we regenerate it. This can happen
|
||||
# if e.g. we've deleted the thumbnails but still have the original
|
||||
# image somewhere.
|
||||
#
|
||||
# Since we have an entry for the thumbnail in the DB we a) know we
|
||||
# have have successfully generated the thumbnail in the past (so we
|
||||
# don't need to worry about repeatedly failing to generate
|
||||
# thumbnails), and b) have already calculated that appropriate
|
||||
# width/height/method so we can just call the "generate exact"
|
||||
# methods.
|
||||
|
||||
# First let's check that we do actually have the original image
|
||||
# still. This will throw a 404 if we don't.
|
||||
# TODO: We should refetch the thumbnails for remote media.
|
||||
await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
if server_name:
|
||||
await self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name,
|
||||
file_id=file_id,
|
||||
media_id=media_id,
|
||||
t_width=file_info.thumbnail.width,
|
||||
t_height=file_info.thumbnail.height,
|
||||
t_method=file_info.thumbnail.method,
|
||||
t_type=file_info.thumbnail.type,
|
||||
)
|
||||
else:
|
||||
await self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id=media_id,
|
||||
t_width=file_info.thumbnail.width,
|
||||
t_height=file_info.thumbnail.height,
|
||||
t_method=file_info.thumbnail.method,
|
||||
t_type=file_info.thumbnail.type,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(
|
||||
request,
|
||||
responder,
|
||||
file_info.thumbnail.type,
|
||||
file_info.thumbnail.length,
|
||||
)
|
||||
else:
|
||||
# This might be because:
|
||||
# 1. We can't create thumbnails for the given media (corrupted or
|
||||
# unsupported file type), or
|
||||
# 2. The thumbnailing process never ran or errored out initially
|
||||
# when the media was first uploaded (these bugs should be
|
||||
# reported and fixed).
|
||||
# Note that we don't attempt to generate a thumbnail now because
|
||||
# `dynamic_thumbnails` is disabled.
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
|
||||
assert request.path is not None
|
||||
respond_with_json(
|
||||
request,
|
||||
400,
|
||||
cs_error(
|
||||
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
|
||||
% (
|
||||
request.path.decode(),
|
||||
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
|
||||
),
|
||||
code=Codes.UNKNOWN,
|
||||
),
|
||||
send_cors=True,
|
||||
)
|
||||
|
||||
def _select_thumbnail(
|
||||
self,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[ThumbnailInfo],
|
||||
file_id: str,
|
||||
url_cache: bool,
|
||||
server_name: Optional[str],
|
||||
) -> Optional[FileInfo]:
|
||||
"""
|
||||
Choose an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: True if this is from a URL cache.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
|
||||
Returns:
|
||||
The thumbnail which best matches the desired parameters.
|
||||
"""
|
||||
desired_method = desired_method.lower()
|
||||
|
||||
# The chosen thumbnail.
|
||||
thumbnail_info = None
|
||||
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
if desired_method == "crop":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
crop_info_list: List[
|
||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||
] = []
|
||||
# Other thumbnails.
|
||||
crop_info_list2: List[
|
||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||
] = []
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info.method != "crop":
|
||||
continue
|
||||
|
||||
t_w = info.width
|
||||
t_h = info.height
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info.type
|
||||
length_quality = info.length
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
crop_info_list.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
else:
|
||||
crop_info_list2.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if crop_info_list:
|
||||
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
|
||||
elif crop_info_list2:
|
||||
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
|
||||
elif desired_method == "scale":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||
# Other thumbnails.
|
||||
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info.method != "scale":
|
||||
continue
|
||||
|
||||
t_w = info.width
|
||||
t_h = info.height
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info.type
|
||||
length_quality = info.length
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
info_list.append((size_quality, type_quality, length_quality, info))
|
||||
else:
|
||||
info_list2.append(
|
||||
(size_quality, type_quality, length_quality, info)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if info_list:
|
||||
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
|
||||
elif info_list2:
|
||||
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
|
||||
|
||||
if thumbnail_info:
|
||||
return FileInfo(
|
||||
file_id=file_id,
|
||||
url_cache=url_cache,
|
||||
server_name=server_name,
|
||||
thumbnail=thumbnail_info,
|
||||
)
|
||||
|
||||
# No matching thumbnail was found.
|
||||
return None
|
||||
|
|
|
@ -2461,7 +2461,11 @@ class DatabasePool:
|
|||
|
||||
|
||||
def make_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
|
||||
database_engine: BaseDatabaseEngine,
|
||||
column: str,
|
||||
iterable: Collection[Any],
|
||||
*,
|
||||
negative: bool = False,
|
||||
) -> Tuple[str, list]:
|
||||
"""Returns an SQL clause that checks the given column is in the iterable.
|
||||
|
||||
|
@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
|
|||
database_engine
|
||||
column: Name of the column
|
||||
iterable: The values to check the column against.
|
||||
negative: Whether we should check for inequality, i.e. `NOT IN`
|
||||
|
||||
Returns:
|
||||
A tuple of SQL query and the args
|
||||
|
@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
|
|||
if database_engine.supports_using_any_list:
|
||||
# This should hopefully be faster, but also makes postgres query
|
||||
# stats easier to understand.
|
||||
return "%s = ANY(?)" % (column,), [list(iterable)]
|
||||
if not negative:
|
||||
clause = f"{column} = ANY(?)"
|
||||
else:
|
||||
clause = f"{column} != ALL(?)"
|
||||
|
||||
return clause, [list(iterable)]
|
||||
else:
|
||||
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
|
||||
params = ",".join("?" for _ in iterable)
|
||||
if not negative:
|
||||
clause = f"{column} IN ({params})"
|
||||
else:
|
||||
clause = f"{column} NOT IN ({params})"
|
||||
return clause, list(iterable)
|
||||
|
||||
|
||||
# These overloads ensure that `columns` and `iterable` values have the same length.
|
||||
|
|
|
@ -43,11 +43,9 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.util import json_encoder
|
||||
|
@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
|
||||
self._account_data_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._account_data_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="account_data",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("room_account_data", "instance_name", "stream_id"),
|
||||
("room_tags_revisions", "instance_name", "stream_id"),
|
||||
("account_data", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="account_data_sequence",
|
||||
writers=hs.config.worker.writers.account_data,
|
||||
)
|
||||
else:
|
||||
# Multiple writers are not supported for SQLite.
|
||||
#
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
# to support it for unit tests.
|
||||
self._account_data_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"room_account_data",
|
||||
"stream_id",
|
||||
extra_tables=[
|
||||
("account_data", "stream_id"),
|
||||
("room_tags_revisions", "stream_id"),
|
||||
],
|
||||
is_writer=self._instance_name in hs.config.worker.writers.account_data,
|
||||
)
|
||||
self._account_data_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="account_data",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("room_account_data", "instance_name", "stream_id"),
|
||||
("room_tags_revisions", "instance_name", "stream_id"),
|
||||
("account_data", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="account_data_sequence",
|
||||
writers=hs.config.worker.writers.account_data,
|
||||
)
|
||||
|
||||
account_max = self.get_max_account_data_stream_id()
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
|
|
|
@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
|
||||
# Caches which might leak edits must be invalidated for the event being
|
||||
# redacted.
|
||||
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_relations_for_event",
|
||||
(
|
||||
room_id,
|
||||
redacts,
|
||||
),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
|
||||
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
|
||||
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
|
||||
|
@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
if relates_to:
|
||||
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_relations_for_event",
|
||||
(
|
||||
room_id,
|
||||
relates_to,
|
||||
),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
|
||||
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
|
||||
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
|
||||
|
@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
self._attempt_to_invalidate_cache(
|
||||
"get_unread_event_push_actions_by_room_for_user", (room_id,)
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
|
||||
|
||||
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
|
||||
self._attempt_to_invalidate_cache("get_relations_for_event", None)
|
||||
self._attempt_to_invalidate_cache("get_applicable_edit", None)
|
||||
self._attempt_to_invalidate_cache("get_thread_id", None)
|
||||
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
|
||||
|
|
|
@ -50,16 +50,15 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -89,35 +88,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
expiry_ms=30 * 60 * 1000,
|
||||
)
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._can_write_to_device = (
|
||||
self._instance_name in hs.config.worker.writers.to_device
|
||||
)
|
||||
self._can_write_to_device = (
|
||||
self._instance_name in hs.config.worker.writers.to_device
|
||||
)
|
||||
|
||||
self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
|
||||
MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="to_device",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("device_inbox", "instance_name", "stream_id"),
|
||||
("device_federation_outbox", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="device_inbox_sequence",
|
||||
writers=hs.config.worker.writers.to_device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._can_write_to_device = True
|
||||
self._to_device_msg_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"device_inbox",
|
||||
"stream_id",
|
||||
extra_tables=[("device_federation_outbox", "stream_id")],
|
||||
)
|
||||
self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="to_device",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("device_inbox", "instance_name", "stream_id"),
|
||||
("device_federation_outbox", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="device_inbox_sequence",
|
||||
writers=hs.config.worker.writers.to_device,
|
||||
)
|
||||
|
||||
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
|
||||
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
|
||||
|
@ -978,6 +965,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
|
||||
CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1003,6 +991,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
self._remove_dead_devices_from_device_inbox,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
self._cleanup_device_federation_outbox,
|
||||
)
|
||||
|
||||
async def _background_drop_index_device_inbox(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
|
@ -1094,6 +1087,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
return batch_size
|
||||
|
||||
async def _cleanup_device_federation_outbox(
|
||||
self,
|
||||
progress: JsonDict,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
def _cleanup_device_federation_outbox_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
if "max_stream_id" in progress:
|
||||
max_stream_id = progress["max_stream_id"]
|
||||
else:
|
||||
txn.execute("SELECT max(stream_id) FROM device_federation_outbox")
|
||||
res = cast(Tuple[Optional[int]], txn.fetchone())
|
||||
if res[0] is None:
|
||||
# this can only happen if the `device_inbox` table is empty, in which
|
||||
# case we have no work to do.
|
||||
return True
|
||||
else:
|
||||
max_stream_id = res[0]
|
||||
|
||||
start = progress.get("stream_id", 0)
|
||||
stop = start + batch_size
|
||||
|
||||
sql = """
|
||||
SELECT destination FROM device_federation_outbox
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (start, stop))
|
||||
|
||||
destinations = {d for d, in txn}
|
||||
to_remove = set()
|
||||
for d in destinations:
|
||||
try:
|
||||
parse_and_validate_server_name(d)
|
||||
except ValueError:
|
||||
to_remove.add(d)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="device_federation_outbox",
|
||||
column="destination",
|
||||
values=to_remove,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
{
|
||||
"stream_id": stop,
|
||||
"max_stream_id": max_stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
return stop >= max_stream_id
|
||||
|
||||
finished = await self.db_pool.runInteraction(
|
||||
"_cleanup_device_federation_outbox",
|
||||
_cleanup_device_federation_outbox_txn,
|
||||
)
|
||||
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
)
|
||||
|
||||
return batch_size
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||
pass
|
||||
|
|
|
@ -57,10 +57,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
|
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"device_lists_stream",
|
||||
"stream_id",
|
||||
extra_tables=[
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
self._device_list_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="device_lists_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("device_lists_stream", "instance_name", "stream_id"),
|
||||
("user_signature_stream", "instance_name", "stream_id"),
|
||||
("device_lists_outbound_pokes", "instance_name", "stream_id"),
|
||||
("device_lists_changes_in_room", "instance_name", "stream_id"),
|
||||
("device_lists_remote_pending", "instance_name", "stream_id"),
|
||||
],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
sequence_name="device_lists_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
|
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
"stream_id": stream_id,
|
||||
"from_user_id": from_user_id,
|
||||
"user_ids": json_encoder.encode(user_ids),
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"device_lists_stream_idx",
|
||||
index_name="device_lists_stream_user_id",
|
||||
|
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
"device_lists_outbound_pokes",
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
|
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see DeviceWorkerStore.__init__)
|
||||
_device_list_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_stream",
|
||||
keys=("stream_id", "user_id", "device_id"),
|
||||
keys=("instance_name", "stream_id", "user_id", "device_id"),
|
||||
values=[
|
||||
(stream_id, user_id, device_id)
|
||||
(self._instance_name, stream_id, user_id, device_id)
|
||||
for stream_id, device_id in zip(stream_ids, device_ids)
|
||||
],
|
||||
)
|
||||
|
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
values = [
|
||||
(
|
||||
destination,
|
||||
self._instance_name,
|
||||
next(stream_id_iterator),
|
||||
user_id,
|
||||
device_id,
|
||||
|
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
table="device_lists_outbound_pokes",
|
||||
keys=(
|
||||
"destination",
|
||||
"instance_name",
|
||||
"stream_id",
|
||||
"user_id",
|
||||
"device_id",
|
||||
|
@ -2157,10 +2158,34 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id,
|
||||
{
|
||||
stream_id: destination
|
||||
for (destination, stream_id, _, _, _, _, _) in values
|
||||
for (destination, _, stream_id, _, _, _, _, _) in values
|
||||
},
|
||||
)
|
||||
|
||||
async def mark_redundant_device_lists_pokes(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
room_id: str,
|
||||
converted_upto_stream_id: int,
|
||||
) -> None:
|
||||
"""If we've calculated the outbound pokes for a given room/device list
|
||||
update, mark any subsequent changes as already converted"""
|
||||
|
||||
sql = """
|
||||
UPDATE device_lists_changes_in_room
|
||||
SET converted_to_destinations = true
|
||||
WHERE stream_id > ? AND user_id = ? AND device_id = ?
|
||||
AND room_id = ? AND NOT converted_to_destinations
|
||||
"""
|
||||
|
||||
def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn
|
||||
)
|
||||
|
||||
def _add_device_outbound_room_poke_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
@ -2186,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
"device_id",
|
||||
"room_id",
|
||||
"stream_id",
|
||||
"instance_name",
|
||||
"converted_to_destinations",
|
||||
"opentracing_context",
|
||||
),
|
||||
|
@ -2195,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id,
|
||||
room_id,
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
# We only need to calculate outbound pokes for local users
|
||||
not self.hs.is_mine_id(user_id),
|
||||
encoded_context,
|
||||
|
@ -2314,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={"stream_id": stream_id},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
desc="add_remote_device_list_to_pending",
|
||||
)
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"e2e_cross_signing_keys",
|
||||
"stream_id",
|
||||
self._cross_signing_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="e2e_cross_signing_keys",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("e2e_cross_signing_keys", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="e2e_cross_signing_keys_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
async def set_e2e_device_keys(
|
||||
|
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -95,6 +95,10 @@ class DeltaState:
|
|||
to_insert: StateMap[str]
|
||||
no_longer_in_room: bool = False
|
||||
|
||||
def is_noop(self) -> bool:
|
||||
"""Whether this state delta is actually empty"""
|
||||
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||
|
||||
|
||||
class PersistEventsStore:
|
||||
"""Contains all the functions for writing events to the database.
|
||||
|
@ -1017,6 +1021,9 @@ class PersistEventsStore:
|
|||
) -> None:
|
||||
"""Update the current state stored in the datatabase for the given room"""
|
||||
|
||||
if state_delta.is_noop():
|
||||
return
|
||||
|
||||
async with self._stream_id_gen.get_next() as stream_ordering:
|
||||
await self.db_pool.runInteraction(
|
||||
"update_current_state",
|
||||
|
@ -1923,7 +1930,12 @@ class PersistEventsStore:
|
|||
|
||||
# Any relation information for the related event must be cleared.
|
||||
self.store._invalidate_cache_and_stream(
|
||||
txn, self.store.get_relations_for_event, (redacted_relates_to,)
|
||||
txn,
|
||||
self.store.get_relations_for_event,
|
||||
(
|
||||
room_id,
|
||||
redacted_relates_to,
|
||||
),
|
||||
)
|
||||
if rel_type == RelationTypes.REFERENCE:
|
||||
self.store._invalidate_cache_and_stream(
|
||||
|
|
|
@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
results = list(txn)
|
||||
# (event_id, parent_id, rel_type) for each relation
|
||||
relations_to_insert: List[Tuple[str, str, str]] = []
|
||||
relations_to_insert: List[Tuple[str, str, str, str]] = []
|
||||
for event_id, event_json_raw in results:
|
||||
try:
|
||||
event_json = db_to_json(event_json_raw)
|
||||
|
@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
if not isinstance(parent_id, str):
|
||||
continue
|
||||
|
||||
relations_to_insert.append((event_id, parent_id, rel_type))
|
||||
room_id = event_json["room_id"]
|
||||
relations_to_insert.append((room_id, event_id, parent_id, rel_type))
|
||||
|
||||
# Insert the missing data, note that we upsert here in case the event
|
||||
# has already been processed.
|
||||
|
@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
txn=txn,
|
||||
table="event_relations",
|
||||
key_names=("event_id",),
|
||||
key_values=[(r[0],) for r in relations_to_insert],
|
||||
key_values=[(r[1],) for r in relations_to_insert],
|
||||
value_names=("relates_to_id", "relation_type"),
|
||||
value_values=[r[1:] for r in relations_to_insert],
|
||||
value_values=[r[2:] for r in relations_to_insert],
|
||||
)
|
||||
|
||||
# Iterate the parent IDs and invalidate caches.
|
||||
cache_tuples = {(r[1],) for r in relations_to_insert}
|
||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
|
||||
txn,
|
||||
self.get_relations_for_event, # type: ignore[attr-defined]
|
||||
{
|
||||
(
|
||||
r[0], # room_id
|
||||
r[2], # parent_id
|
||||
)
|
||||
for r in relations_to_insert
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
|
||||
txn,
|
||||
self.get_thread_summary, # type: ignore[attr-defined]
|
||||
{(r[1],) for r in relations_to_insert},
|
||||
)
|
||||
|
||||
if results:
|
||||
|
|
|
@ -75,12 +75,10 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
|
@ -195,51 +193,35 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
self._stream_id_gen: AbstractStreamIdGenerator
|
||||
self._backfill_id_gen: AbstractStreamIdGenerator
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
# If we're using Postgres than we can use `MultiWriterIdGenerator`
|
||||
# regardless of whether this process writes to the streams or not.
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="events",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[("events", "instance_name", "stream_ordering")],
|
||||
sequence_name="events_stream_seq",
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="backfill",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[("events", "instance_name", "stream_ordering")],
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
else:
|
||||
# Multiple writers are not supported for SQLite.
|
||||
#
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
# to support it for unit tests.
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"events",
|
||||
"stream_ordering",
|
||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
||||
)
|
||||
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="events",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[
|
||||
("events", "instance_name", "stream_ordering"),
|
||||
("current_state_delta_stream", "instance_name", "stream_id"),
|
||||
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||
],
|
||||
sequence_name="events_stream_seq",
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="backfill",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[
|
||||
("events", "instance_name", "stream_ordering"),
|
||||
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||
],
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
||||
|
@ -309,27 +291,17 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="un_partial_stated_event_stream",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[
|
||||
("un_partial_stated_event_stream", "instance_name", "stream_id")
|
||||
],
|
||||
sequence_name="un_partial_stated_event_stream_sequence",
|
||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||
writers=["master"],
|
||||
)
|
||||
else:
|
||||
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"un_partial_stated_event_stream",
|
||||
"stream_id",
|
||||
)
|
||||
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="un_partial_stated_event_stream",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
|
||||
sequence_name="un_partial_stated_event_stream_sequence",
|
||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
|
||||
return (
|
||||
|
|
|
@ -40,13 +40,11 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines._base import IsolationLevel
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
self._instance_name in hs.config.worker.writers.presence
|
||||
)
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._presence_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="presence_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("presence_stream", "instance_name", "stream_id")],
|
||||
sequence_name="presence_stream_sequence",
|
||||
writers=hs.config.worker.writers.presence,
|
||||
)
|
||||
else:
|
||||
self._presence_id_gen = StreamIdGenerator(
|
||||
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
|
||||
)
|
||||
self._presence_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="presence_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("presence_stream", "instance_name", "stream_id")],
|
||||
sequence_name="presence_stream_sequence",
|
||||
writers=hs.config.worker.writers.presence,
|
||||
)
|
||||
|
||||
self.hs = hs
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
|
|
@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
|||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder, unwrapFirstError
|
||||
|
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
|
|||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||
"""
|
||||
|
||||
_push_rules_stream_id_gen: StreamIdGenerator
|
||||
_push_rules_stream_id_gen: MultiWriterIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
|
|||
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
||||
)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"push_rules_stream",
|
||||
"stream_id",
|
||||
is_writer=self._is_push_writer,
|
||||
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="push_rules_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("push_rules_stream", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="push_rules_stream_sequence",
|
||||
writers=hs.config.worker.writers.push_rules,
|
||||
)
|
||||
|
||||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
||||
|
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
|
|||
raise Exception("Not a push writer")
|
||||
|
||||
values = {
|
||||
"instance_name": self._instance_name,
|
||||
"stream_id": stream_id,
|
||||
"event_stream_ordering": event_stream_ordering,
|
||||
"user_id": user_id,
|
||||
|
|
|
@ -40,10 +40,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._pushers_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"pushers",
|
||||
"id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._pushers_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="pushers",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("pushers", "instance_name", "id"),
|
||||
("deleted_pushers", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="pushers_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
|
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
|||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see PusherWorkerStore.__init__)
|
||||
_pushers_id_gen: AbstractStreamIdGenerator
|
||||
_pushers_id_gen: MultiWriterIdGenerator
|
||||
|
||||
async def add_pusher(
|
||||
self,
|
||||
|
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"enabled": enabled,
|
||||
"device_id": device_id,
|
||||
# XXX(quenting): We're only really persisting the access token ID
|
||||
|
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
table="deleted_pushers",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_id": user_id,
|
||||
|
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="deleted_pushers",
|
||||
keys=("stream_id", "app_id", "pushkey", "user_id"),
|
||||
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
|
||||
values=[
|
||||
(stream_id, pusher.app_id, pusher.pushkey, user_id)
|
||||
(
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
pusher.app_id,
|
||||
pusher.pushkey,
|
||||
user_id,
|
||||
)
|
||||
for stream_id, pusher in zip(stream_ids, pushers)
|
||||
],
|
||||
)
|
||||
|
|
|
@ -44,12 +44,10 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines._base import IsolationLevel
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
|
@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
# class below that is used on the main process.
|
||||
self._receipts_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._can_write_to_receipts = (
|
||||
self._instance_name in hs.config.worker.writers.receipts
|
||||
)
|
||||
self._can_write_to_receipts = (
|
||||
self._instance_name in hs.config.worker.writers.receipts
|
||||
)
|
||||
|
||||
self._receipts_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="receipts",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
||||
sequence_name="receipts_sequence",
|
||||
writers=hs.config.worker.writers.receipts,
|
||||
)
|
||||
else:
|
||||
self._can_write_to_receipts = True
|
||||
|
||||
# Multiple writers are not supported for SQLite.
|
||||
#
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
# to support it for unit tests.
|
||||
self._receipts_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"receipts_linearized",
|
||||
"stream_id",
|
||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
|
||||
)
|
||||
self._receipts_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="receipts",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
||||
sequence_name="receipts_sequence",
|
||||
writers=hs.config.worker.writers.receipts,
|
||||
)
|
||||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
|
|
|
@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
@cached(uncached_args=("event",), tree=True)
|
||||
async def get_relations_for_event(
|
||||
self,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event: EventBase,
|
||||
room_id: str,
|
||||
relation_type: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
|
|
|
@ -58,13 +58,11 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
IdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
|
||||
from synapse.util import json_encoder
|
||||
|
@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="un_partial_stated_room_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("un_partial_stated_room_stream", "instance_name", "stream_id")
|
||||
],
|
||||
sequence_name="un_partial_stated_room_stream_sequence",
|
||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||
writers=["master"],
|
||||
)
|
||||
else:
|
||||
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"un_partial_stated_room_stream",
|
||||
"stream_id",
|
||||
)
|
||||
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="un_partial_stated_room_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
|
||||
sequence_name="un_partial_stated_room_stream_sequence",
|
||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
|
|
|
@ -142,6 +142,10 @@ class PostgresEngine(
|
|||
apply stricter checks on new databases versus existing database.
|
||||
"""
|
||||
|
||||
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
|
||||
if allow_unsafe_locale:
|
||||
return
|
||||
|
||||
collation, ctype = self.get_db_locale(txn)
|
||||
|
||||
errors = []
|
||||
|
@ -155,7 +159,9 @@ class PostgresEngine(
|
|||
if errors:
|
||||
raise IncorrectDatabaseSetup(
|
||||
"Database is incorrectly configured:\n\n%s\n\n"
|
||||
"See docs/postgres.md for more information." % ("\n".join(errors))
|
||||
"See docs/postgres.md for more information. You can override this check by"
|
||||
"setting 'allow_unsafe_locale' to true in the database config.",
|
||||
"\n".join(errors),
|
||||
)
|
||||
|
||||
def convert_param_style(self, sql: str) -> str:
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add `instance_name` columns to stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
|
|
@ -0,0 +1,54 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add squences for stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
|
||||
|
||||
-- We need to take the max across all the device lists tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('device_lists_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
|
||||
)
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
|
||||
|
||||
SELECT setval('e2e_cross_signing_keys_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
|
||||
|
||||
SELECT setval('push_rules_stream_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
|
||||
|
||||
-- We need to take the max across all the pusher tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('pushers_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(id), 1) FROM pushers),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
|
||||
)
|
||||
));
|
|
@ -0,0 +1,15 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(8504, 'cleanup_device_federation_outbox', '{}');
|
|
@ -23,15 +23,12 @@ import abc
|
|||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncContextManager,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
|
@ -53,9 +50,11 @@ from synapse.storage.database import (
|
|||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.notifier import ReplicationNotifier
|
||||
|
@ -177,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
||||
|
||||
This class must only be used when the current Synapse process is the sole
|
||||
writer for a stream.
|
||||
|
||||
Args:
|
||||
db_conn(connection): A database connection to use to fetch the
|
||||
initial value of the generator from.
|
||||
table(str): A database table to read the initial value of the id
|
||||
generator from.
|
||||
column(str): The column of the database table to read the initial
|
||||
value from the id generator from.
|
||||
extra_tables(list): List of pairs of database tables and columns to
|
||||
use to source the initial value of the generator from. The value
|
||||
with the largest magnitude is used.
|
||||
step(int): which direction the stream ids grow in. +1 to grow
|
||||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
notifier: "ReplicationNotifier",
|
||||
table: str,
|
||||
column: str,
|
||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||
step: int = 1,
|
||||
is_writer: bool = True,
|
||||
) -> None:
|
||||
assert step != 0
|
||||
self._lock = threading.Lock()
|
||||
self._step: int = step
|
||||
self._current: int = _load_current_id(db_conn, table, column, step)
|
||||
self._is_writer = is_writer
|
||||
for table, column in extra_tables:
|
||||
self._current = (max if step > 0 else min)(
|
||||
self._current, _load_current_id(db_conn, table, column, step)
|
||||
)
|
||||
|
||||
# We use this as an ordered set, as we want to efficiently append items,
|
||||
# remove items and get the first item. Since we insert IDs in order, the
|
||||
# insertion ordering will ensure its in the correct ordering.
|
||||
#
|
||||
# The key and values are the same, but we never look at the values.
|
||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||
|
||||
self._notifier = notifier
|
||||
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
# Advance should never be called on a writer instance, only over replication
|
||||
if self._is_writer:
|
||||
raise Exception("Replication is not supported by writer StreamIdGenerator")
|
||||
|
||||
self._current = (max if self._step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[int, None, None]:
|
||||
try:
|
||||
yield next_id
|
||||
finally:
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
with self._lock:
|
||||
next_ids = range(
|
||||
self._current + self._step,
|
||||
self._current + self._step * (n + 1),
|
||||
self._step,
|
||||
)
|
||||
self._current += n * self._step
|
||||
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[Sequence[int], None, None]:
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
with self._lock:
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
||||
"""
|
||||
Retrieve the next stream ID from within a database transaction.
|
||||
|
||||
Clean-up functions will be called when the transaction finishes.
|
||||
|
||||
Args:
|
||||
txn: The database transaction object.
|
||||
|
||||
Returns:
|
||||
The next stream ID.
|
||||
"""
|
||||
if not self._is_writer:
|
||||
raise Exception("Tried to allocate stream ID on non-writer")
|
||||
|
||||
# Get the next stream ID.
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
def clear_unfinished_id(id_to_clear: int) -> None:
|
||||
"""A function to mark processing this ID as finished"""
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(id_to_clear)
|
||||
|
||||
# Mark this ID as finished once the database transaction itself finishes.
|
||||
txn.call_after(clear_unfinished_id, next_id)
|
||||
txn.call_on_exception(clear_unfinished_id, next_id)
|
||||
|
||||
# Return the new ID.
|
||||
return next_id
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
if not self._is_writer:
|
||||
return self._current
|
||||
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
return next(iter(self._unfinished_ids)) - self._step
|
||||
|
||||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
def get_minimal_local_current_token(self) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||
|
||||
|
@ -432,7 +276,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
# no active writes in progress.
|
||||
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
|
||||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
# This goes and fills out the above state from the database.
|
||||
self._load_current_ids(db_conn, tables)
|
||||
|
||||
self._sequence_gen = build_sequence_generator(
|
||||
db_conn=db_conn,
|
||||
database_engine=db.engine,
|
||||
get_first_callback=lambda _: self._persisted_upto_position,
|
||||
sequence_name=sequence_name,
|
||||
# We only need to set the below if we want it to call
|
||||
# `check_consistency`, but we do that ourselves below so we can
|
||||
# leave them blank.
|
||||
table=None,
|
||||
id_column=None,
|
||||
stream_name=None,
|
||||
positive=positive,
|
||||
)
|
||||
|
||||
# We check that the table and sequence haven't diverged.
|
||||
for table, _, id_column in tables:
|
||||
|
@ -444,9 +303,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
positive=positive,
|
||||
)
|
||||
|
||||
# This goes and fills out the above state from the database.
|
||||
self._load_current_ids(db_conn, tables)
|
||||
|
||||
self._max_seen_allocated_stream_id = max(
|
||||
self._current_positions.values(), default=1
|
||||
)
|
||||
|
@ -480,13 +336,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
# important if we add back a writer after a long time; we want to
|
||||
# consider that a "new" writer, rather than using the old stale
|
||||
# entry here.
|
||||
sql = """
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self._db.engine, "instance_name", self._writers, negative=True
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
DELETE FROM stream_positions
|
||||
WHERE
|
||||
stream_name = ?
|
||||
AND instance_name != ALL(?)
|
||||
AND {clause}
|
||||
"""
|
||||
cur.execute(sql, (self._stream_name, self._writers))
|
||||
cur.execute(sql, [self._stream_name] + args)
|
||||
|
||||
sql = """
|
||||
SELECT instance_name, stream_id FROM stream_positions
|
||||
|
@ -508,12 +368,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
# We add a GREATEST here to ensure that the result is always
|
||||
# positive. (This can be a problem for e.g. backfill streams where
|
||||
# the server has never backfilled).
|
||||
greatest_func = (
|
||||
"GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
|
||||
)
|
||||
max_stream_id = 1
|
||||
for table, _, id_column in tables:
|
||||
sql = """
|
||||
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
|
||||
SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
|
||||
FROM %(table)s
|
||||
""" % {
|
||||
"greatest_func": greatest_func,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
|
@ -913,6 +777,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
|
||||
# We upsert the value, ensuring on conflict that we always increase the
|
||||
# value (or decrease if stream goes backwards).
|
||||
if isinstance(self._db.engine, PostgresEngine):
|
||||
agg = "GREATEST" if self._positive else "LEAST"
|
||||
else:
|
||||
agg = "MAX" if self._positive else "MIN"
|
||||
|
||||
sql = """
|
||||
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
|
||||
VALUES (?, ?, ?)
|
||||
|
@ -920,10 +789,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
DO UPDATE SET
|
||||
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
|
||||
""" % {
|
||||
"agg": "GREATEST" if self._positive else "LEAST",
|
||||
"agg": agg,
|
||||
}
|
||||
|
||||
pos = (self.get_current_token_for_writer(self._instance_name),)
|
||||
pos = self.get_current_token_for_writer(self._instance_name)
|
||||
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
||||
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ import attr
|
|||
from immutabledict import immutabledict
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.types import VerifyKey
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Self, TypedDict
|
||||
from unpaddedbase64 import decode_base64
|
||||
from zope.interface import Interface
|
||||
|
||||
|
@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
# at `self.stream`.
|
||||
return self.instance_map.get(instance_name, self.stream)
|
||||
|
||||
def is_before_or_eq(self, other_token: Self) -> bool:
|
||||
"""Wether this token is before the other token, i.e. every constituent
|
||||
part is before the other.
|
||||
|
||||
Essentially it is `self <= other`.
|
||||
|
||||
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||
imply that the reverse is True.
|
||||
"""
|
||||
if self.stream > other_token.stream:
|
||||
return False
|
||||
|
||||
instances = self.instance_map.keys() | other_token.instance_map.keys()
|
||||
for instance in instances:
|
||||
if self.instance_map.get(
|
||||
instance, self.stream
|
||||
) > other_token.instance_map.get(instance, other_token.stream):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, order=False)
|
||||
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
||||
|
@ -1008,6 +1029,41 @@ class StreamToken:
|
|||
"""Returns the stream ID for the given key."""
|
||||
return getattr(self, key.value)
|
||||
|
||||
def is_before_or_eq(self, other_token: "StreamToken") -> bool:
|
||||
"""Wether this token is before the other token, i.e. every constituent
|
||||
part is before the other.
|
||||
|
||||
Essentially it is `self <= other`.
|
||||
|
||||
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||
imply that the reverse is True.
|
||||
"""
|
||||
|
||||
for _, key in StreamKeyType.__members__.items():
|
||||
if key == StreamKeyType.TYPING:
|
||||
# Typing stream is allowed to "reset", and so comparisons don't
|
||||
# really make sense as is.
|
||||
# TODO: Figure out a better way of tracking resets.
|
||||
continue
|
||||
|
||||
self_value = self.get_field(key)
|
||||
other_value = other_token.get_field(key)
|
||||
|
||||
if isinstance(self_value, RoomStreamToken):
|
||||
assert isinstance(other_value, RoomStreamToken)
|
||||
if not self_value.is_before_or_eq(other_value):
|
||||
return False
|
||||
elif isinstance(self_value, MultiWriterStreamToken):
|
||||
assert isinstance(other_value, MultiWriterStreamToken)
|
||||
if not self_value.is_before_or_eq(other_value):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(other_value, int)
|
||||
if self_value > other_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
StreamToken.START = StreamToken(
|
||||
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
|
||||
|
|
|
@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
|
|||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import nested_logging_context
|
||||
from synapse.logging.context import (
|
||||
ContextResourceUsage,
|
||||
LoggingContext,
|
||||
nested_logging_context,
|
||||
set_current_context,
|
||||
)
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
run_as_background_process,
|
||||
|
@ -81,6 +86,8 @@ class TaskScheduler:
|
|||
MAX_CONCURRENT_RUNNING_TASKS = 5
|
||||
# Time from the last task update after which we will log a warning
|
||||
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
||||
# Report a running task's status and usage every so often.
|
||||
OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._hs = hs
|
||||
|
@ -346,6 +353,33 @@ class TaskScheduler:
|
|||
assert task.id not in self._running_tasks
|
||||
await self._store.delete_scheduled_task(task.id)
|
||||
|
||||
@staticmethod
|
||||
def _log_task_usage(
|
||||
state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
|
||||
) -> None:
|
||||
"""
|
||||
Log a line describing the state and usage of a task.
|
||||
The log line is inspired by / a copy of the request log line format,
|
||||
but with irrelevant fields removed.
|
||||
|
||||
active_time: Time that the task has been running for, in seconds.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||
" [%d dbevts] %r, %r",
|
||||
state,
|
||||
active_time,
|
||||
usage.ru_utime,
|
||||
usage.ru_stime,
|
||||
usage.db_sched_duration_sec,
|
||||
usage.db_txn_duration_sec,
|
||||
int(usage.db_txn_count),
|
||||
usage.evt_db_fetch_count,
|
||||
task.resource_id,
|
||||
task.params,
|
||||
)
|
||||
|
||||
async def _launch_task(self, task: ScheduledTask) -> None:
|
||||
"""Launch a scheduled task now.
|
||||
|
||||
|
@ -360,8 +394,32 @@ class TaskScheduler:
|
|||
)
|
||||
function = self._actions[task.action]
|
||||
|
||||
def _occasional_report(
|
||||
task_log_context: LoggingContext, start_time: float
|
||||
) -> None:
|
||||
"""
|
||||
Helper to log a 'Task continuing' line every so often.
|
||||
"""
|
||||
|
||||
current_time = self._clock.time()
|
||||
calling_context = set_current_context(task_log_context)
|
||||
try:
|
||||
usage = task_log_context.get_resource_usage()
|
||||
TaskScheduler._log_task_usage(
|
||||
"continuing", task, usage, current_time - start_time
|
||||
)
|
||||
finally:
|
||||
set_current_context(calling_context)
|
||||
|
||||
async def wrapper() -> None:
|
||||
with nested_logging_context(task.id):
|
||||
with nested_logging_context(task.id) as log_context:
|
||||
start_time = self._clock.time()
|
||||
occasional_status_call = self._clock.looping_call(
|
||||
_occasional_report,
|
||||
TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
|
||||
log_context,
|
||||
start_time,
|
||||
)
|
||||
try:
|
||||
(status, result, error) = await function(task)
|
||||
except Exception:
|
||||
|
@ -383,6 +441,13 @@ class TaskScheduler:
|
|||
)
|
||||
self._running_tasks.remove(task.id)
|
||||
|
||||
current_time = self._clock.time()
|
||||
usage = log_context.get_resource_usage()
|
||||
TaskScheduler._log_task_usage(
|
||||
status.value, task, usage, current_time - start_time
|
||||
)
|
||||
occasional_status_call.stop()
|
||||
|
||||
# Try launch a new task since we've finished with this one.
|
||||
self._clock.call_later(0.1, self._launch_scheduled_tasks)
|
||||
|
||||
|
|
|
@ -407,3 +407,24 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
|
|||
self.assertFalse(
|
||||
self.get_success(self.store.did_forget(self.alice, self.room_id))
|
||||
)
|
||||
|
||||
def test_deduplicate_joins(self) -> None:
|
||||
"""
|
||||
Test that calling /join multiple times does not store a new state group.
|
||||
"""
|
||||
|
||||
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||
|
||||
sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?"
|
||||
rows = self.get_success(
|
||||
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||
)
|
||||
initial_count = rows[0][0]
|
||||
|
||||
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||
rows = self.get_success(
|
||||
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||
)
|
||||
new_count = rows[0][0]
|
||||
|
||||
self.assertEqual(initial_count, new_count)
|
||||
|
|
|
@ -32,7 +32,7 @@ from twisted.web.resource import Resource
|
|||
from synapse.api.constants import EduTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.handlers.typing import TypingWriterHandler
|
||||
from synapse.handlers.typing import FORGET_TIMEOUT, TypingWriterHandler
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
|
||||
|
@ -501,3 +501,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
}
|
||||
],
|
||||
)
|
||||
|
||||
def test_prune_typing_replication(self) -> None:
|
||||
"""Regression test for `get_all_typing_updates` breaking when we prune
|
||||
old updates
|
||||
"""
|
||||
self.room_members = [U_APPLE, U_BANANA]
|
||||
|
||||
instance_name = self.hs.get_instance_name()
|
||||
|
||||
self.get_success(
|
||||
self.handler.started_typing(
|
||||
target_user=U_APPLE,
|
||||
requester=create_requester(U_APPLE),
|
||||
room_id=ROOM_ID,
|
||||
timeout=10000,
|
||||
)
|
||||
)
|
||||
|
||||
rows, _, _ = self.get_success(
|
||||
self.handler.get_all_typing_updates(
|
||||
instance_name=instance_name,
|
||||
last_id=0,
|
||||
current_id=self.handler.get_current_token(),
|
||||
limit=100,
|
||||
)
|
||||
)
|
||||
self.assertEqual(rows, [(1, [ROOM_ID, [U_APPLE.to_string()]])])
|
||||
|
||||
self.reactor.advance(20000)
|
||||
|
||||
rows, _, _ = self.get_success(
|
||||
self.handler.get_all_typing_updates(
|
||||
instance_name=instance_name,
|
||||
last_id=1,
|
||||
current_id=self.handler.get_current_token(),
|
||||
limit=100,
|
||||
)
|
||||
)
|
||||
self.assertEqual(rows, [(2, [ROOM_ID, []])])
|
||||
|
||||
self.reactor.advance(FORGET_TIMEOUT)
|
||||
|
||||
rows, _, _ = self.get_success(
|
||||
self.handler.get_all_typing_updates(
|
||||
instance_name=instance_name,
|
||||
last_id=1,
|
||||
current_id=self.handler.get_current_token(),
|
||||
limit=100,
|
||||
)
|
||||
)
|
||||
self.assertEqual(rows, [])
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
import itertools
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
@ -46,11 +47,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
|
|||
from synapse.media.filepath import MediaFilePaths
|
||||
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
|
||||
from synapse.media.storage_provider import FileStorageProviderBackend
|
||||
from synapse.media.thumbnailer import ThumbnailProvider
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login
|
||||
from synapse.rest.media.thumbnail_resource import ThumbnailResource
|
||||
from synapse.rest.client import login, media
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomAlias
|
||||
from synapse.util import Clock
|
||||
|
@ -153,68 +154,54 @@ class _TestImage:
|
|||
is_inline: bool = True
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("test_image",),
|
||||
[
|
||||
# small png
|
||||
(
|
||||
_TestImage(
|
||||
SMALL_PNG,
|
||||
b"image/png",
|
||||
b".png",
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
|
||||
b"000000737a7af40000001a49444154789cedc101010000008220"
|
||||
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
|
||||
b"44ae426082"
|
||||
),
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||
b"0000001f15c4890000000d49444154789c636060606000000005"
|
||||
b"0001a5f645400000000049454e44ae426082"
|
||||
),
|
||||
),
|
||||
),
|
||||
# small png with transparency.
|
||||
(
|
||||
_TestImage(
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
|
||||
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
|
||||
b"4154789c636800000082008177cd72b60000000049454e44ae426"
|
||||
b"082"
|
||||
),
|
||||
b"image/png",
|
||||
b".png",
|
||||
# Note that we don't check the output since it varies across
|
||||
# different versions of Pillow.
|
||||
),
|
||||
),
|
||||
# small lossless webp
|
||||
(
|
||||
_TestImage(
|
||||
unhexlify(
|
||||
b"524946461a000000574542505650384c0d0000002f0000001007"
|
||||
b"1011118888fe0700"
|
||||
),
|
||||
b"image/webp",
|
||||
b".webp",
|
||||
),
|
||||
),
|
||||
# an empty file
|
||||
(
|
||||
_TestImage(
|
||||
b"",
|
||||
b"image/gif",
|
||||
b".gif",
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=True,
|
||||
),
|
||||
),
|
||||
# An SVG.
|
||||
(
|
||||
_TestImage(
|
||||
b"""<?xml version="1.0"?>
|
||||
small_png = _TestImage(
|
||||
SMALL_PNG,
|
||||
b"image/png",
|
||||
b".png",
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
|
||||
b"000000737a7af40000001a49444154789cedc101010000008220"
|
||||
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
|
||||
b"44ae426082"
|
||||
),
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||
b"0000001f15c4890000000d49444154789c636060606000000005"
|
||||
b"0001a5f645400000000049454e44ae426082"
|
||||
),
|
||||
)
|
||||
|
||||
small_png_with_transparency = _TestImage(
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
|
||||
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
|
||||
b"4154789c636800000082008177cd72b60000000049454e44ae426"
|
||||
b"082"
|
||||
),
|
||||
b"image/png",
|
||||
b".png",
|
||||
# Note that we don't check the output since it varies across
|
||||
# different versions of Pillow.
|
||||
)
|
||||
|
||||
small_lossless_webp = _TestImage(
|
||||
unhexlify(
|
||||
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
|
||||
),
|
||||
b"image/webp",
|
||||
b".webp",
|
||||
)
|
||||
|
||||
empty_file = _TestImage(
|
||||
b"",
|
||||
b"image/gif",
|
||||
b".gif",
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=True,
|
||||
)
|
||||
|
||||
SVG = _TestImage(
|
||||
b"""<?xml version="1.0"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
|
||||
|
@ -223,19 +210,32 @@ class _TestImage:
|
|||
<circle cx="100" cy="100" r="50" stroke="black"
|
||||
stroke-width="5" fill="red" />
|
||||
</svg>""",
|
||||
b"image/svg",
|
||||
b".svg",
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=True,
|
||||
is_inline=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
b"image/svg",
|
||||
b".svg",
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=True,
|
||||
is_inline=False,
|
||||
)
|
||||
test_images = [
|
||||
small_png,
|
||||
small_png_with_transparency,
|
||||
small_lossless_webp,
|
||||
empty_file,
|
||||
SVG,
|
||||
]
|
||||
urls = [
|
||||
"_matrix/media/r0/thumbnail",
|
||||
"_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
|
||||
]
|
||||
|
||||
|
||||
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
|
||||
class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
servlets = [media.register_servlets]
|
||||
test_image: ClassVar[_TestImage]
|
||||
hijack_auth = True
|
||||
user_id = "@test:user"
|
||||
url: ClassVar[str]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.fetches: List[
|
||||
|
@ -298,6 +298,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
|
||||
|
||||
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
|
||||
|
||||
|
@ -502,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
params = "?width=32&height=32&method=scale"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
|
||||
f"/{self.url}/{self.media_id}{params}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
)
|
||||
|
@ -530,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
|
||||
f"/{self.url}/{self.media_id}{params}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
)
|
||||
|
@ -566,12 +567,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
params = "?width=32&height=32&method=" + method
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
|
||||
f"/{self.url}/{self.media_id}{params}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
)
|
||||
self.pump()
|
||||
|
||||
headers = {
|
||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||
b"Content-Type": [self.test_image.content_type],
|
||||
|
@ -580,7 +580,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
(self.test_image.data, (len(self.test_image.data), headers))
|
||||
)
|
||||
self.pump()
|
||||
|
||||
if expected_found:
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
|
@ -603,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
channel.json_body,
|
||||
{
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
|
||||
"error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
@ -613,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
channel.json_body,
|
||||
{
|
||||
"errcode": "M_NOT_FOUND",
|
||||
"error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
|
||||
"error": f"Not found '/{self.url}/example.com/12345'",
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -625,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
|
||||
content_type = self.test_image.content_type.decode()
|
||||
media_repo = self.hs.get_media_repository()
|
||||
thumbnail_resouce = ThumbnailResource(
|
||||
thumbnail_provider = ThumbnailProvider(
|
||||
self.hs, media_repo, media_repo.media_storage
|
||||
)
|
||||
|
||||
self.assertIsNotNone(
|
||||
thumbnail_resouce._select_thumbnail(
|
||||
thumbnail_provider._select_thumbnail(
|
||||
desired_width=desired_size,
|
||||
desired_height=desired_size,
|
||||
desired_method=method,
|
||||
|
|
|
@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.rest import admin, devices, room, sync
|
||||
from synapse.rest.client import account, keys, login, register
|
||||
from synapse.rest import admin, devices, sync
|
||||
from synapse.rest.client import keys, login, register
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
@ -33,146 +33,6 @@ from synapse.util import Clock
|
|||
from tests import unittest
|
||||
|
||||
|
||||
class DeviceListsTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests regarding device list changes."""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
register.register_servlets,
|
||||
account.register_servlets,
|
||||
room.register_servlets,
|
||||
sync.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def test_receiving_local_device_list_changes(self) -> None:
|
||||
"""Tests that a local users that share a room receive each other's device list
|
||||
changes.
|
||||
"""
|
||||
# Register two users
|
||||
test_device_id = "TESTDEVICE"
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||
|
||||
# Create a room for them to coexist peacefully in
|
||||
new_room_id = self.helper.create_room_as(
|
||||
alice_user_id, is_public=True, tok=alice_access_token
|
||||
)
|
||||
self.assertIsNotNone(new_room_id)
|
||||
|
||||
# Have Bob join the room
|
||||
self.helper.invite(
|
||||
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
|
||||
)
|
||||
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
|
||||
|
||||
# Now have Bob initiate an initial sync (in order to get a since token)
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/sync",
|
||||
access_token=bob_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
next_batch_token = channel.json_body["next_batch"]
|
||||
|
||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||
# which we hope will happen as a result of Alice updating their device list.
|
||||
bob_sync_channel = self.make_request(
|
||||
"GET",
|
||||
f"/sync?since={next_batch_token}&timeout=30000",
|
||||
access_token=bob_access_token,
|
||||
# Start the request, then continue on.
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Have alice update their device list
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=alice_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
# Check that bob's incremental sync contains the updated device list.
|
||||
# If not, the client would only receive the device list update on the
|
||||
# *next* sync.
|
||||
bob_sync_channel.await_result()
|
||||
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
||||
|
||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||
"changed", []
|
||||
)
|
||||
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
|
||||
|
||||
def test_not_receiving_local_device_list_changes(self) -> None:
|
||||
"""Tests a local users DO NOT receive device updates from each other if they do not
|
||||
share a room.
|
||||
"""
|
||||
# Register two users
|
||||
test_device_id = "TESTDEVICE"
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||
|
||||
# These users do not share a room. They are lonely.
|
||||
|
||||
# Have Bob initiate an initial sync (in order to get a since token)
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/sync",
|
||||
access_token=bob_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
next_batch_token = channel.json_body["next_batch"]
|
||||
|
||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||
# which we hope will happen as a result of Alice updating their device list.
|
||||
bob_sync_channel = self.make_request(
|
||||
"GET",
|
||||
f"/sync?since={next_batch_token}&timeout=1000",
|
||||
access_token=bob_access_token,
|
||||
# Start the request, then continue on.
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Have alice update their device list
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=alice_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
# Check that bob's incremental sync does not contain the updated device list.
|
||||
bob_sync_channel.await_result()
|
||||
self.assertEqual(
|
||||
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
|
||||
)
|
||||
|
||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||
"changed", []
|
||||
)
|
||||
self.assertNotIn(
|
||||
alice_user_id, changed_device_lists, bob_sync_channel.json_body
|
||||
)
|
||||
|
||||
|
||||
class DevicesTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
|
|
1609
tests/rest/client/test_media.py
Normal file
1609
tests/rest/client/test_media.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -18,15 +18,39 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from parameterized import parameterized_class
|
||||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, sendtodevice, sync
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("sync_endpoint", "experimental_features"),
|
||||
[
|
||||
("/sync", {}),
|
||||
(
|
||||
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||
# Enable sliding sync
|
||||
{"msc3575_enabled": True},
|
||||
),
|
||||
],
|
||||
)
|
||||
class SendToDeviceTestCase(HomeserverTestCase):
|
||||
"""
|
||||
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
|
||||
|
||||
Attributes:
|
||||
sync_endpoint: The endpoint under test to use for syncing.
|
||||
experimental_features: The experimental features homeserver config to use.
|
||||
"""
|
||||
|
||||
sync_endpoint: str
|
||||
experimental_features: JsonDict
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
|
@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = self.experimental_features
|
||||
return config
|
||||
|
||||
def test_user_to_user(self) -> None:
|
||||
"""A to-device message from one user to another should get delivered"""
|
||||
|
||||
|
@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
self.assertEqual(chan.code, 200, chan.result)
|
||||
|
||||
# check it appears
|
||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
||||
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
expected_result = {
|
||||
"events": [
|
||||
|
@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
}
|
||||
self.assertEqual(channel.json_body["to_device"], expected_result)
|
||||
|
||||
# it should re-appear if we do another sync
|
||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
||||
# it should re-appear if we do another sync because the to-device message is not
|
||||
# deleted until we acknowledge it by sending a `?since=...` parameter in the
|
||||
# next sync request corresponding to the `next_batch` value from the response.
|
||||
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["to_device"], expected_result)
|
||||
|
||||
# it should *not* appear if we do an incremental sync
|
||||
sync_token = channel.json_body["next_batch"]
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={sync_token}",
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
|
||||
|
@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
|
||||
# now sync: we should get two of the three
|
||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
||||
# now sync: we should get two of the three (because burst_count=2)
|
||||
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
msgs = channel.json_body["to_device"]["events"]
|
||||
self.assertEqual(len(msgs), 2)
|
||||
for i in range(2):
|
||||
self.assertEqual(
|
||||
msgs[i],
|
||||
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
|
||||
{
|
||||
"sender": user1,
|
||||
"type": "m.room_key_request",
|
||||
"content": {"idx": i},
|
||||
},
|
||||
)
|
||||
sync_token = channel.json_body["next_batch"]
|
||||
|
||||
|
@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
|
||||
# ... which should arrive
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={sync_token}",
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
msgs = channel.json_body["to_device"]["events"]
|
||||
|
@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
# now sync: we should get two of the three
|
||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
||||
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
msgs = channel.json_body["to_device"]["events"]
|
||||
self.assertEqual(len(msgs), 2)
|
||||
|
@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
|
||||
# ... which should arrive
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={sync_token}",
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
msgs = channel.json_body["to_device"]["events"]
|
||||
|
@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
user2_tok = self.login("u2", "pass", "d2")
|
||||
|
||||
# Do an initial sync
|
||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
||||
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
sync_token = channel.json_body["next_batch"]
|
||||
|
||||
|
@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
self.assertEqual(chan.code, 200, chan.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
messages = channel.json_body.get("to_device", {}).get("events", [])
|
||||
|
@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
|||
sync_token = channel.json_body["next_batch"]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
messages = channel.json_body.get("to_device", {}).get("events", [])
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
from parameterized import parameterized
|
||||
from parameterized import parameterized, parameterized_class
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
|
@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("sync_endpoint", "experimental_features"),
|
||||
[
|
||||
("/sync", {}),
|
||||
(
|
||||
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||
# Enable sliding sync
|
||||
{"msc3575_enabled": True},
|
||||
),
|
||||
],
|
||||
)
|
||||
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Tests regarding device list (`device_lists`) changes.
|
||||
|
||||
Attributes:
|
||||
sync_endpoint: The endpoint under test to use for syncing.
|
||||
experimental_features: The experimental features homeserver config to use.
|
||||
"""
|
||||
|
||||
sync_endpoint: str
|
||||
experimental_features: JsonDict
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
sync.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = self.experimental_features
|
||||
return config
|
||||
|
||||
def test_receiving_local_device_list_changes(self) -> None:
|
||||
"""Tests that a local users that share a room receive each other's device list
|
||||
changes.
|
||||
"""
|
||||
# Register two users
|
||||
test_device_id = "TESTDEVICE"
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||
|
||||
# Create a room for them to coexist peacefully in
|
||||
new_room_id = self.helper.create_room_as(
|
||||
alice_user_id, is_public=True, tok=alice_access_token
|
||||
)
|
||||
self.assertIsNotNone(new_room_id)
|
||||
|
||||
# Have Bob join the room
|
||||
self.helper.invite(
|
||||
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
|
||||
)
|
||||
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
|
||||
|
||||
# Now have Bob initiate an initial sync (in order to get a since token)
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.sync_endpoint,
|
||||
access_token=bob_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
next_batch_token = channel.json_body["next_batch"]
|
||||
|
||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||
# which we hope will happen as a result of Alice updating their device list.
|
||||
bob_sync_channel = self.make_request(
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
|
||||
access_token=bob_access_token,
|
||||
# Start the request, then continue on.
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Have alice update their device list
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=alice_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that bob's incremental sync contains the updated device list.
|
||||
# If not, the client would only receive the device list update on the
|
||||
# *next* sync.
|
||||
bob_sync_channel.await_result()
|
||||
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
||||
|
||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||
"changed", []
|
||||
)
|
||||
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
|
||||
|
||||
def test_not_receiving_local_device_list_changes(self) -> None:
|
||||
"""Tests a local users DO NOT receive device updates from each other if they do not
|
||||
share a room.
|
||||
"""
|
||||
# Register two users
|
||||
test_device_id = "TESTDEVICE"
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||
|
||||
# These users do not share a room. They are lonely.
|
||||
|
||||
# Have Bob initiate an initial sync (in order to get a since token)
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.sync_endpoint,
|
||||
access_token=bob_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
next_batch_token = channel.json_body["next_batch"]
|
||||
|
||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||
# which we hope will happen as a result of Alice updating their device list.
|
||||
bob_sync_channel = self.make_request(
|
||||
"GET",
|
||||
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
|
||||
access_token=bob_access_token,
|
||||
# Start the request, then continue on.
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Have alice update their device list
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=alice_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that bob's incremental sync does not contain the updated device list.
|
||||
bob_sync_channel.await_result()
|
||||
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
||||
|
||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||
"changed", []
|
||||
)
|
||||
self.assertNotIn(
|
||||
alice_user_id, changed_device_lists, bob_sync_channel.json_body
|
||||
)
|
||||
|
||||
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
|
||||
"""Tests that a user with no rooms still receives their own device list updates"""
|
||||
device_id = "TESTDEVICE"
|
||||
test_device_id = "TESTDEVICE"
|
||||
|
||||
# Register a user and login, creating a device
|
||||
self.user_id = self.register_user("kermit", "monkey")
|
||||
self.tok = self.login("kermit", "monkey", device_id=device_id)
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
# Request an initial sync
|
||||
channel = self.make_request("GET", "/sync", access_token=self.tok)
|
||||
channel = self.make_request(
|
||||
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
next_batch = channel.json_body["next_batch"]
|
||||
|
||||
|
@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
|||
# It won't return until something has happened
|
||||
incremental_sync_channel = self.make_request(
|
||||
"GET",
|
||||
f"/sync?since={next_batch}&timeout=30000",
|
||||
access_token=self.tok,
|
||||
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
|
||||
access_token=alice_access_token,
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Change our device's display name
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"devices/{device_id}",
|
||||
f"devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "freeze ray",
|
||||
},
|
||||
access_token=self.tok,
|
||||
access_token=alice_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
|
@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
|||
).get("changed", [])
|
||||
|
||||
self.assertIn(
|
||||
self.user_id, device_list_changes, incremental_sync_channel.json_body
|
||||
alice_user_id, device_list_changes, incremental_sync_channel.json_body
|
||||
)
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("sync_endpoint", "experimental_features"),
|
||||
[
|
||||
("/sync", {}),
|
||||
(
|
||||
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||
# Enable sliding sync
|
||||
{"msc3575_enabled": True},
|
||||
),
|
||||
],
|
||||
)
|
||||
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Tests regarding device one time keys (`device_one_time_keys_count`) changes.
|
||||
|
||||
Attributes:
|
||||
sync_endpoint: The endpoint under test to use for syncing.
|
||||
experimental_features: The experimental features homeserver config to use.
|
||||
"""
|
||||
|
||||
sync_endpoint: str
|
||||
experimental_features: JsonDict
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = self.experimental_features
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
def test_no_device_one_time_keys(self) -> None:
|
||||
"""
|
||||
Tests when no one time keys set, it still has the default `signed_curve25519` in
|
||||
`device_one_time_keys_count`
|
||||
"""
|
||||
test_device_id = "TESTDEVICE"
|
||||
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
# Request an initial sync
|
||||
channel = self.make_request(
|
||||
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check for those one time key counts
|
||||
self.assertDictEqual(
|
||||
channel.json_body["device_one_time_keys_count"],
|
||||
# Note that "signed_curve25519" is always returned in key count responses
|
||||
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||
{"signed_curve25519": 0},
|
||||
channel.json_body["device_one_time_keys_count"],
|
||||
)
|
||||
|
||||
def test_returns_device_one_time_keys(self) -> None:
|
||||
"""
|
||||
Tests that one time keys for the device/user are counted correctly in the `/sync`
|
||||
response
|
||||
"""
|
||||
test_device_id = "TESTDEVICE"
|
||||
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
# Upload one time keys for the user/device
|
||||
keys: JsonDict = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
res = self.get_success(
|
||||
self.e2e_keys_handler.upload_keys_for_user(
|
||||
alice_user_id, test_device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
# Note that "signed_curve25519" is always returned in key count responses
|
||||
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||
self.assertDictEqual(
|
||||
res,
|
||||
{"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
|
||||
)
|
||||
|
||||
# Request an initial sync
|
||||
channel = self.make_request(
|
||||
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check for those one time key counts
|
||||
self.assertDictEqual(
|
||||
channel.json_body["device_one_time_keys_count"],
|
||||
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
|
||||
channel.json_body["device_one_time_keys_count"],
|
||||
)
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("sync_endpoint", "experimental_features"),
|
||||
[
|
||||
("/sync", {}),
|
||||
(
|
||||
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||
# Enable sliding sync
|
||||
{"msc3575_enabled": True},
|
||||
),
|
||||
],
|
||||
)
|
||||
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
|
||||
|
||||
Attributes:
|
||||
sync_endpoint: The endpoint under test to use for syncing.
|
||||
experimental_features: The experimental features homeserver config to use.
|
||||
"""
|
||||
|
||||
sync_endpoint: str
|
||||
experimental_features: JsonDict
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = self.experimental_features
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
def test_no_device_unused_fallback_key(self) -> None:
|
||||
"""
|
||||
Test when no unused fallback key is set, it just returns an empty list. The MSC
|
||||
says "The device_unused_fallback_key_types parameter must be present if the
|
||||
server supports fallback keys.",
|
||||
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
|
||||
"""
|
||||
test_device_id = "TESTDEVICE"
|
||||
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
# Request an initial sync
|
||||
channel = self.make_request(
|
||||
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check for those one time key counts
|
||||
self.assertListEqual(
|
||||
channel.json_body["device_unused_fallback_key_types"],
|
||||
[],
|
||||
channel.json_body["device_unused_fallback_key_types"],
|
||||
)
|
||||
|
||||
def test_returns_device_one_time_keys(self) -> None:
|
||||
"""
|
||||
Tests that device unused fallback key type is returned correctly in the `/sync`
|
||||
"""
|
||||
test_device_id = "TESTDEVICE"
|
||||
|
||||
alice_user_id = self.register_user("alice", "correcthorse")
|
||||
alice_access_token = self.login(
|
||||
alice_user_id, "correcthorse", device_id=test_device_id
|
||||
)
|
||||
|
||||
# We shouldn't have any unused fallback keys yet
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
|
||||
)
|
||||
self.assertEqual(res, [])
|
||||
|
||||
# Upload a fallback key for the user/device
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
self.get_success(
|
||||
self.e2e_keys_handler.upload_keys_for_user(
|
||||
alice_user_id,
|
||||
test_device_id,
|
||||
{"fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
# We should now have an unused alg1 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"], fallback_res)
|
||||
|
||||
# Request an initial sync
|
||||
channel = self.make_request(
|
||||
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check for the unused fallback key types
|
||||
self.assertListEqual(
|
||||
channel.json_body["device_unused_fallback_key_types"],
|
||||
["alg1"],
|
||||
channel.json_body["device_unused_fallback_key_types"],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
|
|||
# from a regular 404.
|
||||
file_id = "abcdefg12345"
|
||||
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
||||
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
|
||||
f,
|
||||
fname,
|
||||
finish,
|
||||
):
|
||||
f.write(SMALL_PNG)
|
||||
self.get_success(finish())
|
||||
|
||||
media_storage = hs.get_media_repository().media_storage
|
||||
|
||||
ctx = media_storage.store_into_file(file_info)
|
||||
(f, fname) = self.get_success(ctx.__aenter__())
|
||||
f.write(SMALL_PNG)
|
||||
self.get_success(ctx.__aexit__(None, None, None))
|
||||
|
||||
self.get_success(
|
||||
self.store.store_cached_remote_media(
|
||||
|
|
|
@ -30,163 +30,34 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.storage.util.sequence import (
|
||||
LocalSequenceGenerator,
|
||||
PostgresSequenceGenerator,
|
||||
SequenceGenerator,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class StreamIdGeneratorTestCase(HomeserverTestCase):
|
||||
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.db_pool: DatabasePool = self.store.db_pool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
|
||||
|
||||
def _create_id_generator(self) -> StreamIdGenerator:
|
||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
||||
return StreamIdGenerator(
|
||||
db_conn=conn,
|
||||
notifier=self.hs.get_replication_notifier(),
|
||||
table="foobar",
|
||||
column="stream_id",
|
||||
)
|
||||
|
||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||
|
||||
def test_initial_value(self) -> None:
|
||||
"""Check that we read the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
def test_single_gen_next(self) -> None:
|
||||
"""Check that we correctly increment the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
async with id_gen.get_next() as next_id:
|
||||
# We haven't persisted `next_id` yet; current token is still 123
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
# But we did learn what the next value is
|
||||
self.assertEqual(next_id, 124)
|
||||
|
||||
# Once the context manager closes we assume that the `next_id` has been
|
||||
# written to the DB.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next sensibly."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist each in turn.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 125)
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next, even when their IDs
|
||||
created and persisted in different orders."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist them in a different order, starting with 126 from ctx3.
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
# We haven't persisted 124 from ctx1 yet---current token is still 123.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now persist 124 from ctx1.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
# Current token is then 124, waiting for 125 to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
# Finally persist 125 from ctx2.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Current token is then 126 (skipping over 125).
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request two new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
|
||||
# Persist ctx2 first.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Still waiting on ctx1's ID to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now request a third stream ID. It should be 126 (the smallest ID that
|
||||
# we've not yet handed out.)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
|
||||
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.db_pool: DatabasePool = self.store.db_pool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
|
||||
else:
|
||||
self.seq_gen = LocalSequenceGenerator(lambda _: 0)
|
||||
|
||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
|
@ -221,44 +92,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
|
||||
def _insert(txn: LoggingTransaction) -> None:
|
||||
for _ in range(number):
|
||||
next_val = self.seq_gen.get_next_id_txn(txn)
|
||||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
||||
(instance_name,),
|
||||
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
|
||||
(
|
||||
next_val,
|
||||
instance_name,
|
||||
),
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||
""",
|
||||
(instance_name,),
|
||||
(instance_name, next_val, next_val),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||
|
||||
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
|
||||
"""Insert one row as the given instance with given stream_id, updating
|
||||
the postgres sequence position to match.
|
||||
"""
|
||||
|
||||
def _insert(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (?, ?)",
|
||||
(
|
||||
stream_id,
|
||||
instance_name,
|
||||
),
|
||||
)
|
||||
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||
""",
|
||||
(instance_name, stream_id, stream_id),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
||||
|
||||
class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||
def test_empty(self) -> None:
|
||||
"""Test an ID generator against an empty database gives sensible
|
||||
current positions.
|
||||
|
@ -347,6 +201,176 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||
|
||||
def test_get_next_txn(self) -> None:
|
||||
"""Test that the `get_next_txn` function works correctly."""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
|
||||
def _get_next_txn(txn: LoggingTransaction) -> None:
|
||||
stream_id = id_gen.get_next_txn(txn)
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_restart_during_out_of_order_persistence(self) -> None:
|
||||
"""Test that restarting a process while another process is writing out
|
||||
of order updates are handled correctly.
|
||||
"""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Persist two rows at once
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
|
||||
s1 = self.get_success(ctx1.__aenter__())
|
||||
s2 = self.get_success(ctx2.__aenter__())
|
||||
|
||||
self.assertEqual(s1, 8)
|
||||
self.assertEqual(s2, 9)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# We finish persisting the second row before restart
|
||||
self.get_success(ctx2.__aexit__(None, None, None))
|
||||
|
||||
# We simulate a restart of another worker by just creating a new ID gen.
|
||||
id_gen_worker = self._create_id_generator("worker")
|
||||
|
||||
# Restarted worker should not see the second persisted row
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Now if we persist the first row then both instances should jump ahead
|
||||
# correctly.
|
||||
self.get_success(ctx1.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
id_gen_worker.advance("master", 9)
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
||||
|
||||
|
||||
class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
|
||||
"""Insert one row as the given instance with given stream_id, updating
|
||||
the postgres sequence position to match.
|
||||
"""
|
||||
|
||||
def _insert(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
|
||||
(
|
||||
stream_id,
|
||||
instance_name,
|
||||
),
|
||||
)
|
||||
|
||||
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||
""",
|
||||
(instance_name, stream_id, stream_id),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
||||
|
||||
def test_get_persisted_upto_position(self) -> None:
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
# The following tests are a bit cheeky in that we notify about new
|
||||
# positions via `advance` without *actually* advancing the postgres
|
||||
# sequence.
|
||||
|
||||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("worker", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
||||
def test_get_persisted_upto_position_get_next(self) -> None:
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions when `get_next` is called.
|
||||
"""
|
||||
|
||||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||
|
||||
async def _get_next_async() -> None:
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# We assume that so long as `get_next` does correctly advance the
|
||||
# `persisted_upto_position` in this case, then it will be correct in the
|
||||
# other cases that are tested above (since they'll hit the same code).
|
||||
|
||||
def test_multi_instance(self) -> None:
|
||||
"""Test that reads and writes from multiple processes are handled
|
||||
correctly.
|
||||
|
@ -453,145 +477,6 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
|
||||
)
|
||||
|
||||
def test_get_next_txn(self) -> None:
|
||||
"""Test that the `get_next_txn` function works correctly."""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
|
||||
def _get_next_txn(txn: LoggingTransaction) -> None:
|
||||
stream_id = id_gen.get_next_txn(txn)
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_get_persisted_upto_position(self) -> None:
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
# The following tests are a bit cheeky in that we notify about new
|
||||
# positions via `advance` without *actually* advancing the postgres
|
||||
# sequence.
|
||||
|
||||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("worker", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
||||
def test_get_persisted_upto_position_get_next(self) -> None:
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions when `get_next` is called.
|
||||
"""
|
||||
|
||||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||
|
||||
async def _get_next_async() -> None:
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# We assume that so long as `get_next` does correctly advance the
|
||||
# `persisted_upto_position` in this case, then it will be correct in the
|
||||
# other cases that are tested above (since they'll hit the same code).
|
||||
|
||||
def test_restart_during_out_of_order_persistence(self) -> None:
|
||||
"""Test that restarting a process while another process is writing out
|
||||
of order updates are handled correctly.
|
||||
"""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Persist two rows at once
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
|
||||
s1 = self.get_success(ctx1.__aenter__())
|
||||
s2 = self.get_success(ctx2.__aenter__())
|
||||
|
||||
self.assertEqual(s1, 8)
|
||||
self.assertEqual(s2, 9)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# We finish persisting the second row before restart
|
||||
self.get_success(ctx2.__aexit__(None, None, None))
|
||||
|
||||
# We simulate a restart of another worker by just creating a new ID gen.
|
||||
id_gen_worker = self._create_id_generator("worker")
|
||||
|
||||
# Restarted worker should not see the second persisted row
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Now if we persist the first row then both instances should jump ahead
|
||||
# correctly.
|
||||
self.get_success(ctx1.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
id_gen_worker.advance("master", 9)
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
||||
|
||||
def test_writer_config_change(self) -> None:
|
||||
"""Test that changing the writer config correctly works."""
|
||||
|
||||
|
|
Loading…
Reference in a new issue