mirror of
https://github.com/element-hq/synapse
synced 2024-10-02 09:12:43 +00:00
Merge branch 'develop' into shay/account_suspension_pt_2
This commit is contained in:
commit
9b56bf8497
78 changed files with 4470 additions and 1682 deletions
|
@ -1,3 +1,10 @@
|
||||||
|
# Synapse 1.108.0 (2024-05-28)
|
||||||
|
|
||||||
|
No significant changes since 1.108.0rc1.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Synapse 1.108.0rc1 (2024-05-21)
|
# Synapse 1.108.0rc1 (2024-05-21)
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
|
|
8
Cargo.lock
generated
8
Cargo.lock
generated
|
@ -485,18 +485,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.202"
|
version = "1.0.203"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395"
|
checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.202"
|
version = "1.0.203"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838"
|
checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|
1
changelog.d/17083.misc
Normal file
1
changelog.d/17083.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve DB usage when fetching related events.
|
1
changelog.d/17164.bugfix
Normal file
1
changelog.d/17164.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix deduplicating of membership events to not create unused state groups.
|
1
changelog.d/17167.feature
Normal file
1
changelog.d/17167.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.
|
1
changelog.d/17176.misc
Normal file
1
changelog.d/17176.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.
|
1
changelog.d/17204.doc
Normal file
1
changelog.d/17204.doc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update OIDC documentation: by default Matrix doesn't query userinfo endpoint, then claims should be put on id_token.
|
1
changelog.d/17211.misc
Normal file
1
changelog.d/17211.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce work of calculating outbound device lists updates.
|
1
changelog.d/17213.feature
Normal file
1
changelog.d/17213.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support MSC3916 by adding unstable media endpoints to `_matrix/client` (#17213).
|
1
changelog.d/17215.bugfix
Normal file
1
changelog.d/17215.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
|
1
changelog.d/17219.feature
Normal file
1
changelog.d/17219.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add logging to tasks managed by the task scheduler, showing CPU and database usage.
|
1
changelog.d/17226.misc
Normal file
1
changelog.d/17226.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Move towards using `MultiWriterIdGenerator` everywhere.
|
1
changelog.d/17229.misc
Normal file
1
changelog.d/17229.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
|
1
changelog.d/17238.misc
Normal file
1
changelog.d/17238.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Change the `allow_unsafe_locale` config option to also apply when setting up new databases.
|
1
changelog.d/17239.misc
Normal file
1
changelog.d/17239.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
|
1
changelog.d/17240.bugfix
Normal file
1
changelog.d/17240.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Ignore attempts to send to-device messages to bad users, to avoid log spam when we try to connect to the bad server.
|
1
changelog.d/17241.bugfix
Normal file
1
changelog.d/17241.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix handling of duplicate concurrent uploading of device one-time-keys.
|
1
changelog.d/17242.misc
Normal file
1
changelog.d/17242.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Clean out invalid destinations from `device_federation_outbox` table.
|
1
changelog.d/17246.misc
Normal file
1
changelog.d/17246.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
|
1
changelog.d/17250.misc
Normal file
1
changelog.d/17250.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Stop logging errors when receiving invalid User IDs in key querys requests.
|
1
changelog.d/17251.bugfix
Normal file
1
changelog.d/17251.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix reporting of default tags to Sentry, such as worker name. Broke in v1.108.0.
|
1
changelog.d/17252.bugfix
Normal file
1
changelog.d/17252.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix bug where typing updates would not be sent when using workers after a restart.
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
||||||
|
matrix-synapse-py3 (1.108.0) stable; urgency=medium
|
||||||
|
|
||||||
|
* New Synapse release 1.108.0.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Tue, 28 May 2024 11:54:22 +0100
|
||||||
|
|
||||||
matrix-synapse-py3 (1.108.0~rc1) stable; urgency=medium
|
matrix-synapse-py3 (1.108.0~rc1) stable; urgency=medium
|
||||||
|
|
||||||
* New Synapse release 1.108.0rc1.
|
* New Synapse release 1.108.0rc1.
|
||||||
|
|
|
@ -525,6 +525,8 @@ oidc_providers:
|
||||||
(`Options > Security > ID Token signature algorithm` and `Options > Security >
|
(`Options > Security > ID Token signature algorithm` and `Options > Security >
|
||||||
Access Token signature algorithm`)
|
Access Token signature algorithm`)
|
||||||
- Scopes: OpenID, Email and Profile
|
- Scopes: OpenID, Email and Profile
|
||||||
|
- Force claims into `id_token`
|
||||||
|
(`Options > Advanced > Force claims to be returned in ID Token`)
|
||||||
- Allowed redirection addresses for login (`Options > Basic > Allowed
|
- Allowed redirection addresses for login (`Options > Basic > Allowed
|
||||||
redirection addresses for login` ) :
|
redirection addresses for login` ) :
|
||||||
`[synapse public baseurl]/_synapse/client/oidc/callback`
|
`[synapse public baseurl]/_synapse/client/oidc/callback`
|
||||||
|
|
|
@ -242,12 +242,11 @@ host all all ::1/128 ident
|
||||||
|
|
||||||
### Fixing incorrect `COLLATE` or `CTYPE`
|
### Fixing incorrect `COLLATE` or `CTYPE`
|
||||||
|
|
||||||
Synapse will refuse to set up a new database if it has the wrong values of
|
Synapse will refuse to start when using a database with incorrect values of
|
||||||
`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
|
`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
|
||||||
of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
|
`database` section of the config, is set to true. Using different locales can
|
||||||
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
|
cause issues if the locale library is updated from underneath the database, or
|
||||||
underneath the database, or if a different version of the locale is used on any
|
if a different version of the locale is used on any replicas.
|
||||||
replicas.
|
|
||||||
|
|
||||||
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
|
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
|
||||||
the correct locale parameter (as shown above). It is also possible to change the
|
the correct locale parameter (as shown above). It is also possible to change the
|
||||||
|
|
24
poetry.lock
generated
24
poetry.lock
generated
|
@ -1536,13 +1536,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "phonenumbers"
|
name = "phonenumbers"
|
||||||
version = "8.13.35"
|
version = "8.13.37"
|
||||||
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
|
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "phonenumbers-8.13.35-py2.py3-none-any.whl", hash = "sha256:58286a8e617bd75f541e04313b28c36398be6d4443a778c85e9617a93c391310"},
|
{file = "phonenumbers-8.13.37-py2.py3-none-any.whl", hash = "sha256:4ea00ef5012422c08c7955c21131e7ae5baa9a3ef52cf2d561e963f023006b80"},
|
||||||
{file = "phonenumbers-8.13.35.tar.gz", hash = "sha256:64f061a967dcdae11e1c59f3688649e697b897110a33bb74d5a69c3e35321245"},
|
{file = "phonenumbers-8.13.37.tar.gz", hash = "sha256:bd315fed159aea0516f7c367231810fe8344d5bec26156b88fa18374c11d1cf2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1673,13 +1673,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "prometheus-client"
|
name = "prometheus-client"
|
||||||
version = "0.19.0"
|
version = "0.20.0"
|
||||||
description = "Python client for the Prometheus monitoring system."
|
description = "Python client for the Prometheus monitoring system."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"},
|
{file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
|
||||||
{file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"},
|
{file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
|
@ -1915,12 +1915,12 @@ plugins = ["importlib-metadata"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyicu"
|
name = "pyicu"
|
||||||
version = "2.13"
|
version = "2.13.1"
|
||||||
description = "Python extension wrapping the ICU C++ API"
|
description = "Python extension wrapping the ICU C++ API"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "PyICU-2.13.tar.gz", hash = "sha256:d481be888975df3097c2790241bbe8518f65c9676a74957cdbe790e559c828f6"},
|
{file = "PyICU-2.13.1.tar.gz", hash = "sha256:d4919085eaa07da12bade8ee721e7bbf7ade0151ca0f82946a26c8f4b98cdceb"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1997,13 +1997,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyopenssl"
|
name = "pyopenssl"
|
||||||
version = "24.0.0"
|
version = "24.1.0"
|
||||||
description = "Python wrapper module around the OpenSSL library"
|
description = "Python wrapper module around the OpenSSL library"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "pyOpenSSL-24.0.0-py3-none-any.whl", hash = "sha256:ba07553fb6fd6a7a2259adb9b84e12302a9a8a75c44046e8bb5d3e5ee887e3c3"},
|
{file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
|
||||||
{file = "pyOpenSSL-24.0.0.tar.gz", hash = "sha256:6aa33039a93fffa4563e655b61d11364d01264be8ccb49906101e02a334530bf"},
|
{file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2011,7 +2011,7 @@ cryptography = ">=41.0.5,<43"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
|
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
|
||||||
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
|
test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pysaml2"
|
name = "pysaml2"
|
||||||
|
|
|
@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "matrix-synapse"
|
name = "matrix-synapse"
|
||||||
version = "1.108.0rc1"
|
version = "1.108.0"
|
||||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
|
@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
|
||||||
# add a lower bound to the Jinja2 dependency.
|
# add a lower bound to the Jinja2 dependency.
|
||||||
Jinja2 = ">=3.0"
|
Jinja2 = ">=3.0"
|
||||||
bleach = ">=1.4.3"
|
bleach = ">=1.4.3"
|
||||||
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
|
# We use `Self`, which were added in `typing-extensions` 4.0.
|
||||||
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
|
typing-extensions = ">=4.0"
|
||||||
# generic over ParamSpecs.
|
|
||||||
typing-extensions = ">=3.10.0.1"
|
|
||||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||||
# with the latest security patches.
|
# with the latest security patches.
|
||||||
cryptography = ">=3.4.7"
|
cryptography = ">=3.4.7"
|
||||||
|
|
|
@ -777,22 +777,74 @@ class Porter:
|
||||||
await self._setup_events_stream_seqs()
|
await self._setup_events_stream_seqs()
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"un_partial_stated_event_stream_sequence",
|
"un_partial_stated_event_stream_sequence",
|
||||||
("un_partial_stated_event_stream",),
|
[("un_partial_stated_event_stream", "stream_id")],
|
||||||
)
|
)
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
"device_inbox_sequence",
|
||||||
|
[
|
||||||
|
("device_inbox", "stream_id"),
|
||||||
|
("device_federation_outbox", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"account_data_sequence",
|
"account_data_sequence",
|
||||||
("room_account_data", "room_tags_revisions", "account_data"),
|
[
|
||||||
|
("room_account_data", "stream_id"),
|
||||||
|
("room_tags_revisions", "stream_id"),
|
||||||
|
("account_data", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"receipts_sequence",
|
||||||
|
[
|
||||||
|
("receipts_linearized", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"presence_stream_sequence",
|
||||||
|
[
|
||||||
|
("presence_stream", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
|
|
||||||
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
|
|
||||||
await self._setup_auth_chain_sequence()
|
await self._setup_auth_chain_sequence()
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"application_services_txn_id_seq",
|
"application_services_txn_id_seq",
|
||||||
("application_services_txns",),
|
[
|
||||||
"txn_id",
|
(
|
||||||
|
"application_services_txns",
|
||||||
|
"txn_id",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"device_lists_sequence",
|
||||||
|
[
|
||||||
|
("device_lists_stream", "stream_id"),
|
||||||
|
("user_signature_stream", "stream_id"),
|
||||||
|
("device_lists_outbound_pokes", "stream_id"),
|
||||||
|
("device_lists_changes_in_room", "stream_id"),
|
||||||
|
("device_lists_remote_pending", "stream_id"),
|
||||||
|
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"e2e_cross_signing_keys_sequence",
|
||||||
|
[
|
||||||
|
("e2e_cross_signing_keys", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"push_rules_stream_sequence",
|
||||||
|
[
|
||||||
|
("push_rules_stream", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"pushers_sequence",
|
||||||
|
[
|
||||||
|
("pushers", "id"),
|
||||||
|
("deleted_pushers", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3. Get tables.
|
# Step 3. Get tables.
|
||||||
|
@ -1101,12 +1153,11 @@ class Porter:
|
||||||
async def _setup_sequence(
|
async def _setup_sequence(
|
||||||
self,
|
self,
|
||||||
sequence_name: str,
|
sequence_name: str,
|
||||||
stream_id_tables: Iterable[str],
|
stream_id_tables: Iterable[Tuple[str, str]],
|
||||||
column_name: str = "stream_id",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set a sequence to the correct value."""
|
"""Set a sequence to the correct value."""
|
||||||
current_stream_ids = []
|
current_stream_ids = []
|
||||||
for stream_id_table in stream_id_tables:
|
for stream_id_table, column_name in stream_id_tables:
|
||||||
max_stream_id = cast(
|
max_stream_id = cast(
|
||||||
int,
|
int,
|
||||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
|
|
|
@ -681,17 +681,17 @@ def setup_sentry(hs: "HomeServer") -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
# We set some default tags that give some context to this instance
|
# We set some default tags that give some context to this instance
|
||||||
with sentry_sdk.configure_scope() as scope:
|
global_scope = sentry_sdk.Scope.get_global_scope()
|
||||||
scope.set_tag("matrix_server_name", hs.config.server.server_name)
|
global_scope.set_tag("matrix_server_name", hs.config.server.server_name)
|
||||||
|
|
||||||
app = (
|
app = (
|
||||||
hs.config.worker.worker_app
|
hs.config.worker.worker_app
|
||||||
if hs.config.worker.worker_app
|
if hs.config.worker.worker_app
|
||||||
else "synapse.app.homeserver"
|
else "synapse.app.homeserver"
|
||||||
)
|
)
|
||||||
name = hs.get_instance_name()
|
name = hs.get_instance_name()
|
||||||
scope.set_tag("worker_app", app)
|
global_scope.set_tag("worker_app", app)
|
||||||
scope.set_tag("worker_name", name)
|
global_scope.set_tag("worker_name", name)
|
||||||
|
|
||||||
|
|
||||||
def setup_sdnotify(hs: "HomeServer") -> None:
|
def setup_sdnotify(hs: "HomeServer") -> None:
|
||||||
|
|
|
@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
|
||||||
# MSC3391: Removing account data.
|
# MSC3391: Removing account data.
|
||||||
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
|
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
|
||||||
|
|
||||||
|
# MSC3575 (Sliding Sync API endpoints)
|
||||||
|
self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
|
||||||
|
|
||||||
# MSC3773: Thread notifications
|
# MSC3773: Thread notifications
|
||||||
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
|
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
|
||||||
|
|
||||||
|
@ -440,3 +443,7 @@ class ExperimentalConfig(Config):
|
||||||
self.msc3823_account_suspension = experimental.get(
|
self.msc3823_account_suspension = experimental.get(
|
||||||
"msc3823_account_suspension", False
|
"msc3823_account_suspension", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.msc3916_authenticated_media_enabled = experimental.get(
|
||||||
|
"msc3916_authenticated_media_enabled", False
|
||||||
|
)
|
||||||
|
|
|
@ -906,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
context=opentracing_context,
|
context=opentracing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.store.mark_redundant_device_lists_pokes(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
room_id=room_id,
|
||||||
|
converted_upto_stream_id=stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Notify replication that we've updated the device list stream.
|
# Notify replication that we've updated the device list stream.
|
||||||
self.notifier.notify_replication()
|
self.notifier.notify_replication()
|
||||||
|
|
||||||
|
|
|
@ -236,6 +236,13 @@ class DeviceMessageHandler:
|
||||||
local_messages = {}
|
local_messages = {}
|
||||||
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
for user_id, by_device in messages.items():
|
for user_id, by_device in messages.items():
|
||||||
|
if not UserID.is_valid(user_id):
|
||||||
|
logger.warning(
|
||||||
|
"Ignoring attempt to send device message to invalid user: %r",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# add an opentracing log entry for each message
|
# add an opentracing log entry for each message
|
||||||
for device_id, message_content in by_device.items():
|
for device_id, message_content in by_device.items():
|
||||||
log_kv(
|
log_kv(
|
||||||
|
|
|
@ -53,6 +53,9 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
|
||||||
|
|
||||||
|
|
||||||
class E2eKeysHandler:
|
class E2eKeysHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
@ -62,6 +65,7 @@ class E2eKeysHandler:
|
||||||
self._appservice_handler = hs.get_application_service_handler()
|
self._appservice_handler = hs.get_application_service_handler()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||||
|
|
||||||
federation_registry = hs.get_federation_registry()
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
@ -145,6 +149,11 @@ class E2eKeysHandler:
|
||||||
remote_queries = {}
|
remote_queries = {}
|
||||||
|
|
||||||
for user_id, device_ids in device_keys_query.items():
|
for user_id, device_ids in device_keys_query.items():
|
||||||
|
if not UserID.is_valid(user_id):
|
||||||
|
# Ignore invalid user IDs, which is the same behaviour as if
|
||||||
|
# the user existed but had no keys.
|
||||||
|
continue
|
||||||
|
|
||||||
# we use UserID.from_string to catch invalid user ids
|
# we use UserID.from_string to catch invalid user ids
|
||||||
if self.is_mine(UserID.from_string(user_id)):
|
if self.is_mine(UserID.from_string(user_id)):
|
||||||
local_query[user_id] = device_ids
|
local_query[user_id] = device_ids
|
||||||
|
@ -855,45 +864,53 @@ class E2eKeysHandler:
|
||||||
async def _upload_one_time_keys_for_user(
|
async def _upload_one_time_keys_for_user(
|
||||||
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
|
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
# We take out a lock so that we don't have to worry about a client
|
||||||
"Adding one_time_keys %r for device %r for user %r at %d",
|
# sending duplicate requests.
|
||||||
one_time_keys.keys(),
|
lock_key = f"{user_id}_{device_id}"
|
||||||
device_id,
|
async with self._worker_lock_handler.acquire_lock(
|
||||||
user_id,
|
ONE_TIME_KEY_UPLOAD, lock_key
|
||||||
time_now,
|
):
|
||||||
)
|
logger.info(
|
||||||
|
"Adding one_time_keys %r for device %r for user %r at %d",
|
||||||
|
one_time_keys.keys(),
|
||||||
|
device_id,
|
||||||
|
user_id,
|
||||||
|
time_now,
|
||||||
|
)
|
||||||
|
|
||||||
# make a list of (alg, id, key) tuples
|
# make a list of (alg, id, key) tuples
|
||||||
key_list = []
|
key_list = []
|
||||||
for key_id, key_obj in one_time_keys.items():
|
for key_id, key_obj in one_time_keys.items():
|
||||||
algorithm, key_id = key_id.split(":")
|
algorithm, key_id = key_id.split(":")
|
||||||
key_list.append((algorithm, key_id, key_obj))
|
key_list.append((algorithm, key_id, key_obj))
|
||||||
|
|
||||||
# First we check if we have already persisted any of the keys.
|
# First we check if we have already persisted any of the keys.
|
||||||
existing_key_map = await self.store.get_e2e_one_time_keys(
|
existing_key_map = await self.store.get_e2e_one_time_keys(
|
||||||
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||||
)
|
)
|
||||||
|
|
||||||
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
||||||
for algorithm, key_id, key in key_list:
|
for algorithm, key_id, key in key_list:
|
||||||
ex_json = existing_key_map.get((algorithm, key_id), None)
|
ex_json = existing_key_map.get((algorithm, key_id), None)
|
||||||
if ex_json:
|
if ex_json:
|
||||||
if not _one_time_keys_match(ex_json, key):
|
if not _one_time_keys_match(ex_json, key):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
(
|
(
|
||||||
"One time key %s:%s already exists. "
|
"One time key %s:%s already exists. "
|
||||||
"Old key: %s; new key: %r"
|
"Old key: %s; new key: %r"
|
||||||
|
)
|
||||||
|
% (algorithm, key_id, ex_json, key),
|
||||||
)
|
)
|
||||||
% (algorithm, key_id, ex_json, key),
|
else:
|
||||||
|
new_keys.append(
|
||||||
|
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
new_keys.append(
|
|
||||||
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
|
|
||||||
)
|
|
||||||
|
|
||||||
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
|
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
|
||||||
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
|
await self.store.add_e2e_one_time_keys(
|
||||||
|
user_id, device_id, time_now, new_keys
|
||||||
|
)
|
||||||
|
|
||||||
async def upload_signing_keys_for_user(
|
async def upload_signing_keys_for_user(
|
||||||
self, user_id: str, keys: JsonDict
|
self, user_id: str, keys: JsonDict
|
||||||
|
|
|
@ -496,13 +496,6 @@ class EventCreationHandler:
|
||||||
|
|
||||||
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
|
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
|
||||||
|
|
||||||
self.membership_types_to_include_profile_data_in = {
|
|
||||||
Membership.JOIN,
|
|
||||||
Membership.KNOCK,
|
|
||||||
}
|
|
||||||
if self.hs.config.server.include_profile_data_on_invite:
|
|
||||||
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
|
||||||
|
|
||||||
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
|
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
|
||||||
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
|
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
|
||||||
|
|
||||||
|
@ -594,8 +587,6 @@ class EventCreationHandler:
|
||||||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||||
etc.
|
etc.
|
||||||
|
|
||||||
Adds display names to Join membership events.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester
|
requester
|
||||||
event_dict: An entire event
|
event_dict: An entire event
|
||||||
|
@ -683,29 +674,6 @@ class EventCreationHandler:
|
||||||
|
|
||||||
self.validator.validate_builder(builder)
|
self.validator.validate_builder(builder)
|
||||||
|
|
||||||
if builder.type == EventTypes.Member:
|
|
||||||
membership = builder.content.get("membership", None)
|
|
||||||
target = UserID.from_string(builder.state_key)
|
|
||||||
|
|
||||||
if membership in self.membership_types_to_include_profile_data_in:
|
|
||||||
# If event doesn't include a display name, add one.
|
|
||||||
profile = self.profile_handler
|
|
||||||
content = builder.content
|
|
||||||
|
|
||||||
try:
|
|
||||||
if "displayname" not in content:
|
|
||||||
displayname = await profile.get_displayname(target)
|
|
||||||
if displayname is not None:
|
|
||||||
content["displayname"] = displayname
|
|
||||||
if "avatar_url" not in content:
|
|
||||||
avatar_url = await profile.get_avatar_url(target)
|
|
||||||
if avatar_url is not None:
|
|
||||||
content["avatar_url"] = avatar_url
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"Failed to get profile information for %r: %s", target, e
|
|
||||||
)
|
|
||||||
|
|
||||||
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
|
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
|
||||||
if require_consent and not is_exempt:
|
if require_consent and not is_exempt:
|
||||||
await self.assert_accepted_privacy_policy(requester)
|
await self.assert_accepted_privacy_policy(requester)
|
||||||
|
|
|
@ -590,7 +590,7 @@ class RegistrationHandler:
|
||||||
# moving away from bare excepts is a good thing to do.
|
# moving away from bare excepts is a good thing to do.
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
|
||||||
|
|
||||||
async def _auto_join_rooms(self, user_id: str) -> None:
|
async def _auto_join_rooms(self, user_id: str) -> None:
|
||||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||||
|
|
|
@ -393,9 +393,9 @@ class RelationsHandler:
|
||||||
|
|
||||||
# Attempt to find another event to use as the latest event.
|
# Attempt to find another event to use as the latest event.
|
||||||
potential_events, _ = await self._main_store.get_relations_for_event(
|
potential_events, _ = await self._main_store.get_relations_for_event(
|
||||||
|
room_id,
|
||||||
event_id,
|
event_id,
|
||||||
event,
|
event,
|
||||||
room_id,
|
|
||||||
RelationTypes.THREAD,
|
RelationTypes.THREAD,
|
||||||
direction=Direction.FORWARDS,
|
direction=Direction.FORWARDS,
|
||||||
)
|
)
|
||||||
|
|
|
@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
self.event_auth_handler = hs.get_event_auth_handler()
|
self.event_auth_handler = hs.get_event_auth_handler()
|
||||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||||
|
|
||||||
|
self._membership_types_to_include_profile_data_in = {
|
||||||
|
Membership.JOIN,
|
||||||
|
Membership.KNOCK,
|
||||||
|
}
|
||||||
|
if self.hs.config.server.include_profile_data_on_invite:
|
||||||
|
self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
||||||
|
|
||||||
self.member_linearizer: Linearizer = Linearizer(name="member")
|
self.member_linearizer: Linearizer = Linearizer(name="member")
|
||||||
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
||||||
|
|
||||||
|
@ -785,9 +792,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
if (
|
if (
|
||||||
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
||||||
) or requester.shadow_banned:
|
) or requester.shadow_banned:
|
||||||
# Strip profile data, knowing that new profile data will be added to the
|
# Strip profile data, knowing that new profile data will be added to
|
||||||
# event's content in event_creation_handler.create_event() using the target's
|
# the event's content below using the target's global profile.
|
||||||
# global profile.
|
|
||||||
content.pop("displayname", None)
|
content.pop("displayname", None)
|
||||||
content.pop("avatar_url", None)
|
content.pop("avatar_url", None)
|
||||||
|
|
||||||
|
@ -823,6 +829,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
if action in ["kick", "unban"]:
|
if action in ["kick", "unban"]:
|
||||||
effective_membership_state = "leave"
|
effective_membership_state = "leave"
|
||||||
|
|
||||||
|
if effective_membership_state not in Membership.LIST:
|
||||||
|
raise SynapseError(400, "Invalid membership key")
|
||||||
|
|
||||||
|
# Add profile data for joins etc, if no per-room profile.
|
||||||
|
if (
|
||||||
|
effective_membership_state
|
||||||
|
in self._membership_types_to_include_profile_data_in
|
||||||
|
):
|
||||||
|
# If event doesn't include a display name, add one.
|
||||||
|
profile = self.profile_handler
|
||||||
|
|
||||||
|
try:
|
||||||
|
if "displayname" not in content:
|
||||||
|
displayname = await profile.get_displayname(target)
|
||||||
|
if displayname is not None:
|
||||||
|
content["displayname"] = displayname
|
||||||
|
if "avatar_url" not in content:
|
||||||
|
avatar_url = await profile.get_avatar_url(target)
|
||||||
|
if avatar_url is not None:
|
||||||
|
content["avatar_url"] = avatar_url
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("Failed to get profile information for %r: %s", target, e)
|
||||||
|
|
||||||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||||
# invite into a normal invite before we can handle the join.
|
# invite into a normal invite before we can handle the join.
|
||||||
if third_party_signed is not None:
|
if third_party_signed is not None:
|
||||||
|
|
|
@ -28,11 +28,14 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
List,
|
List,
|
||||||
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -128,6 +131,8 @@ class SyncVersion(Enum):
|
||||||
|
|
||||||
# Traditional `/sync` endpoint
|
# Traditional `/sync` endpoint
|
||||||
SYNC_V2 = "sync_v2"
|
SYNC_V2 = "sync_v2"
|
||||||
|
# Part of MSC3575 Sliding Sync
|
||||||
|
E2EE_SYNC = "e2ee_sync"
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
@ -279,6 +284,43 @@ class SyncResult:
|
||||||
or self.device_lists
|
or self.device_lists
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty(next_batch: StreamToken) -> "SyncResult":
|
||||||
|
"Return a new empty result"
|
||||||
|
return SyncResult(
|
||||||
|
next_batch=next_batch,
|
||||||
|
presence=[],
|
||||||
|
account_data=[],
|
||||||
|
joined=[],
|
||||||
|
invited=[],
|
||||||
|
knocked=[],
|
||||||
|
archived=[],
|
||||||
|
to_device=[],
|
||||||
|
device_lists=DeviceListUpdates(),
|
||||||
|
device_one_time_keys_count={},
|
||||||
|
device_unused_fallback_key_types=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class E2eeSyncResult:
|
||||||
|
"""
|
||||||
|
Attributes:
|
||||||
|
next_batch: Token for the next sync
|
||||||
|
to_device: List of direct messages for the device.
|
||||||
|
device_lists: List of user_ids whose devices have changed
|
||||||
|
device_one_time_keys_count: Dict of algorithm to count for one time keys
|
||||||
|
for this device
|
||||||
|
device_unused_fallback_key_types: List of key types that have an unused fallback
|
||||||
|
key
|
||||||
|
"""
|
||||||
|
|
||||||
|
next_batch: StreamToken
|
||||||
|
to_device: List[JsonDict]
|
||||||
|
device_lists: DeviceListUpdates
|
||||||
|
device_one_time_keys_count: JsonMapping
|
||||||
|
device_unused_fallback_key_types: List[str]
|
||||||
|
|
||||||
|
|
||||||
class SyncHandler:
|
class SyncHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
@ -322,6 +364,31 @@ class SyncHandler:
|
||||||
|
|
||||||
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
|
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||||
|
request_key: SyncRequestKey,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
timeout: int = 0,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> SyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||||
|
request_key: SyncRequestKey,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
timeout: int = 0,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> E2eeSyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
async def wait_for_sync_for_user(
|
async def wait_for_sync_for_user(
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
|
@ -331,7 +398,18 @@ class SyncHandler:
|
||||||
since_token: Optional[StreamToken] = None,
|
since_token: Optional[StreamToken] = None,
|
||||||
timeout: int = 0,
|
timeout: int = 0,
|
||||||
full_state: bool = False,
|
full_state: bool = False,
|
||||||
) -> SyncResult:
|
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||||
|
|
||||||
|
async def wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: SyncVersion,
|
||||||
|
request_key: SyncRequestKey,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
timeout: int = 0,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> Union[SyncResult, E2eeSyncResult]:
|
||||||
"""Get the sync for a client if we have new data for it now. Otherwise
|
"""Get the sync for a client if we have new data for it now. Otherwise
|
||||||
wait for new data to arrive on the server. If the timeout expires, then
|
wait for new data to arrive on the server. If the timeout expires, then
|
||||||
return an empty sync result.
|
return an empty sync result.
|
||||||
|
@ -344,8 +422,10 @@ class SyncHandler:
|
||||||
since_token: The point in the stream to sync from.
|
since_token: The point in the stream to sync from.
|
||||||
timeout: How long to wait for new data to arrive before giving up.
|
timeout: How long to wait for new data to arrive before giving up.
|
||||||
full_state: Whether to return the full state for each room.
|
full_state: Whether to return the full state for each room.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
||||||
|
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
|
||||||
"""
|
"""
|
||||||
# If the user is not part of the mau group, then check that limits have
|
# If the user is not part of the mau group, then check that limits have
|
||||||
# not been exceeded (if not part of the group by this point, almost certain
|
# not been exceeded (if not part of the group by this point, almost certain
|
||||||
|
@ -366,6 +446,29 @@ class SyncHandler:
|
||||||
logger.debug("Returning sync response for %s", user_id)
|
logger.debug("Returning sync response for %s", user_id)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def _wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||||
|
since_token: Optional[StreamToken],
|
||||||
|
timeout: int,
|
||||||
|
full_state: bool,
|
||||||
|
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||||
|
) -> SyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def _wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||||
|
since_token: Optional[StreamToken],
|
||||||
|
timeout: int,
|
||||||
|
full_state: bool,
|
||||||
|
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||||
|
) -> E2eeSyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
async def _wait_for_sync_for_user(
|
async def _wait_for_sync_for_user(
|
||||||
self,
|
self,
|
||||||
sync_config: SyncConfig,
|
sync_config: SyncConfig,
|
||||||
|
@ -374,7 +477,17 @@ class SyncHandler:
|
||||||
timeout: int,
|
timeout: int,
|
||||||
full_state: bool,
|
full_state: bool,
|
||||||
cache_context: ResponseCacheContext[SyncRequestKey],
|
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||||
) -> SyncResult:
|
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||||
|
|
||||||
|
async def _wait_for_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: SyncVersion,
|
||||||
|
since_token: Optional[StreamToken],
|
||||||
|
timeout: int,
|
||||||
|
full_state: bool,
|
||||||
|
cache_context: ResponseCacheContext[SyncRequestKey],
|
||||||
|
) -> Union[SyncResult, E2eeSyncResult]:
|
||||||
"""The start of the machinery that produces a /sync response.
|
"""The start of the machinery that produces a /sync response.
|
||||||
|
|
||||||
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
|
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
|
||||||
|
@ -401,6 +514,24 @@ class SyncHandler:
|
||||||
if context:
|
if context:
|
||||||
context.tag = sync_label
|
context.tag = sync_label
|
||||||
|
|
||||||
|
if since_token is not None:
|
||||||
|
# We need to make sure this worker has caught up with the token. If
|
||||||
|
# this returns false it means we timed out waiting, and we should
|
||||||
|
# just return an empty response.
|
||||||
|
start = self.clock.time_msec()
|
||||||
|
if not await self.notifier.wait_for_stream_token(since_token):
|
||||||
|
logger.warning(
|
||||||
|
"Timed out waiting for worker to catch up. Returning empty response"
|
||||||
|
)
|
||||||
|
return SyncResult.empty(since_token)
|
||||||
|
|
||||||
|
# If we've spent significant time waiting to catch up, take it off
|
||||||
|
# the timeout.
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
if now - start > 1_000:
|
||||||
|
timeout -= now - start
|
||||||
|
timeout = max(timeout, 0)
|
||||||
|
|
||||||
# if we have a since token, delete any to-device messages before that token
|
# if we have a since token, delete any to-device messages before that token
|
||||||
# (since we now know that the device has received them)
|
# (since we now know that the device has received them)
|
||||||
if since_token is not None:
|
if since_token is not None:
|
||||||
|
@ -417,14 +548,16 @@ class SyncHandler:
|
||||||
if timeout == 0 or since_token is None or full_state:
|
if timeout == 0 or since_token is None or full_state:
|
||||||
# we are going to return immediately, so don't bother calling
|
# we are going to return immediately, so don't bother calling
|
||||||
# notifier.wait_for_events.
|
# notifier.wait_for_events.
|
||||||
result: SyncResult = await self.current_sync_for_user(
|
result: Union[SyncResult, E2eeSyncResult] = (
|
||||||
sync_config, sync_version, since_token, full_state=full_state
|
await self.current_sync_for_user(
|
||||||
|
sync_config, sync_version, since_token, full_state=full_state
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Otherwise, we wait for something to happen and report it to the user.
|
# Otherwise, we wait for something to happen and report it to the user.
|
||||||
async def current_sync_callback(
|
async def current_sync_callback(
|
||||||
before_token: StreamToken, after_token: StreamToken
|
before_token: StreamToken, after_token: StreamToken
|
||||||
) -> SyncResult:
|
) -> Union[SyncResult, E2eeSyncResult]:
|
||||||
return await self.current_sync_for_user(
|
return await self.current_sync_for_user(
|
||||||
sync_config, sync_version, since_token
|
sync_config, sync_version, since_token
|
||||||
)
|
)
|
||||||
|
@ -456,14 +589,43 @@ class SyncHandler:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def current_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.SYNC_V2],
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> SyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def current_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: Literal[SyncVersion.E2EE_SYNC],
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> E2eeSyncResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
async def current_sync_for_user(
|
async def current_sync_for_user(
|
||||||
self,
|
self,
|
||||||
sync_config: SyncConfig,
|
sync_config: SyncConfig,
|
||||||
sync_version: SyncVersion,
|
sync_version: SyncVersion,
|
||||||
since_token: Optional[StreamToken] = None,
|
since_token: Optional[StreamToken] = None,
|
||||||
full_state: bool = False,
|
full_state: bool = False,
|
||||||
) -> SyncResult:
|
) -> Union[SyncResult, E2eeSyncResult]: ...
|
||||||
"""Generates the response body of a sync result, represented as a SyncResult.
|
|
||||||
|
async def current_sync_for_user(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
sync_version: SyncVersion,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
full_state: bool = False,
|
||||||
|
) -> Union[SyncResult, E2eeSyncResult]:
|
||||||
|
"""
|
||||||
|
Generates the response body of a sync result, represented as a
|
||||||
|
`SyncResult`/`E2eeSyncResult`.
|
||||||
|
|
||||||
This is a wrapper around `generate_sync_result` which starts an open tracing
|
This is a wrapper around `generate_sync_result` which starts an open tracing
|
||||||
span to track the sync. See `generate_sync_result` for the next part of your
|
span to track the sync. See `generate_sync_result` for the next part of your
|
||||||
|
@ -474,15 +636,25 @@ class SyncHandler:
|
||||||
sync_version: Determines what kind of sync response to generate.
|
sync_version: Determines what kind of sync response to generate.
|
||||||
since_token: The point in the stream to sync from.p.
|
since_token: The point in the stream to sync from.p.
|
||||||
full_state: Whether to return the full state for each room.
|
full_state: Whether to return the full state for each room.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
|
||||||
|
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
|
||||||
"""
|
"""
|
||||||
with start_active_span("sync.current_sync_for_user"):
|
with start_active_span("sync.current_sync_for_user"):
|
||||||
log_kv({"since_token": since_token})
|
log_kv({"since_token": since_token})
|
||||||
|
|
||||||
# Go through the `/sync` v2 path
|
# Go through the `/sync` v2 path
|
||||||
if sync_version == SyncVersion.SYNC_V2:
|
if sync_version == SyncVersion.SYNC_V2:
|
||||||
sync_result: SyncResult = await self.generate_sync_result(
|
sync_result: Union[SyncResult, E2eeSyncResult] = (
|
||||||
sync_config, since_token, full_state
|
await self.generate_sync_result(
|
||||||
|
sync_config, since_token, full_state
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
|
||||||
|
elif sync_version == SyncVersion.E2EE_SYNC:
|
||||||
|
sync_result = await self.generate_e2ee_sync_result(
|
||||||
|
sync_config, since_token
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -1691,6 +1863,96 @@ class SyncHandler:
|
||||||
next_batch=sync_result_builder.now_token,
|
next_batch=sync_result_builder.now_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def generate_e2ee_sync_result(
|
||||||
|
self,
|
||||||
|
sync_config: SyncConfig,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
) -> E2eeSyncResult:
|
||||||
|
"""
|
||||||
|
Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
|
||||||
|
|
||||||
|
This is represented by a `E2eeSyncResult` struct, which is built from small
|
||||||
|
pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
|
||||||
|
mutable ("inout") parameter to various helper functions. These retrieve and
|
||||||
|
process the data which forms the sync body, often writing to the
|
||||||
|
`sync_result_builder` to store their output.
|
||||||
|
|
||||||
|
At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
|
||||||
|
instance to signify that the sync calculation is complete.
|
||||||
|
"""
|
||||||
|
user_id = sync_config.user.to_string()
|
||||||
|
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||||
|
if app_service:
|
||||||
|
# We no longer support AS users using /sync directly.
|
||||||
|
# See https://github.com/matrix-org/matrix-doc/issues/1144
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
sync_result_builder = await self.get_sync_result_builder(
|
||||||
|
sync_config,
|
||||||
|
since_token,
|
||||||
|
full_state=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Calculate `to_device` events
|
||||||
|
await self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||||
|
|
||||||
|
# 2. Calculate `device_lists`
|
||||||
|
# Device list updates are sent if a since token is provided.
|
||||||
|
device_lists = DeviceListUpdates()
|
||||||
|
include_device_list_updates = bool(since_token and since_token.device_list_key)
|
||||||
|
if include_device_list_updates:
|
||||||
|
# Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
|
||||||
|
# is used in calculate_user_changes below.
|
||||||
|
#
|
||||||
|
# TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
|
||||||
|
# figure out the membership changes/derived info needed for
|
||||||
|
# `_generate_sync_entry_for_device_list()`. In the future, we should try to
|
||||||
|
# refactor this away.
|
||||||
|
(
|
||||||
|
newly_joined_rooms,
|
||||||
|
newly_left_rooms,
|
||||||
|
) = await self._generate_sync_entry_for_rooms(sync_result_builder)
|
||||||
|
|
||||||
|
# This uses the sync_result_builder.joined which is set in
|
||||||
|
# `_generate_sync_entry_for_rooms`, if that didn't find any joined
|
||||||
|
# rooms for some reason it is a no-op.
|
||||||
|
(
|
||||||
|
newly_joined_or_invited_or_knocked_users,
|
||||||
|
newly_left_users,
|
||||||
|
) = sync_result_builder.calculate_user_changes()
|
||||||
|
|
||||||
|
device_lists = await self._generate_sync_entry_for_device_list(
|
||||||
|
sync_result_builder,
|
||||||
|
newly_joined_rooms=newly_joined_rooms,
|
||||||
|
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
|
||||||
|
newly_left_rooms=newly_left_rooms,
|
||||||
|
newly_left_users=newly_left_users,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
|
||||||
|
device_id = sync_config.device_id
|
||||||
|
one_time_keys_count: JsonMapping = {}
|
||||||
|
unused_fallback_key_types: List[str] = []
|
||||||
|
if device_id:
|
||||||
|
# TODO: We should have a way to let clients differentiate between the states of:
|
||||||
|
# * no change in OTK count since the provided since token
|
||||||
|
# * the server has zero OTKs left for this device
|
||||||
|
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
|
||||||
|
one_time_keys_count = await self.store.count_e2e_one_time_keys(
|
||||||
|
user_id, device_id
|
||||||
|
)
|
||||||
|
unused_fallback_key_types = list(
|
||||||
|
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return E2eeSyncResult(
|
||||||
|
to_device=sync_result_builder.to_device,
|
||||||
|
device_lists=device_lists,
|
||||||
|
device_one_time_keys_count=one_time_keys_count,
|
||||||
|
device_unused_fallback_key_types=unused_fallback_key_types,
|
||||||
|
next_batch=sync_result_builder.now_token,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_sync_result_builder(
|
async def get_sync_result_builder(
|
||||||
self,
|
self,
|
||||||
sync_config: SyncConfig,
|
sync_config: SyncConfig,
|
||||||
|
@ -1889,7 +2151,7 @@ class SyncHandler:
|
||||||
users_that_have_changed = (
|
users_that_have_changed = (
|
||||||
await self._device_handler.get_device_changes_in_shared_rooms(
|
await self._device_handler.get_device_changes_in_shared_rooms(
|
||||||
user_id,
|
user_id,
|
||||||
sync_result_builder.joined_room_ids,
|
joined_room_ids,
|
||||||
from_token=since_token,
|
from_token=since_token,
|
||||||
now_token=sync_result_builder.now_token,
|
now_token=sync_result_builder.now_token,
|
||||||
)
|
)
|
||||||
|
|
|
@ -477,9 +477,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for room_id in changed_rooms:
|
for room_id in changed_rooms:
|
||||||
serial = self._room_serials[room_id]
|
serial = self._room_serials.get(room_id)
|
||||||
if last_id < serial <= current_id:
|
if serial and last_id < serial <= current_id:
|
||||||
typing = self._room_typing[room_id]
|
typing = self._room_typing.get(room_id, set())
|
||||||
rows.append((serial, [room_id, list(typing)]))
|
rows.append((serial, [room_id, list(typing)]))
|
||||||
rows.sort()
|
rows.sort()
|
||||||
|
|
||||||
|
|
|
@ -650,7 +650,7 @@ class MediaRepository:
|
||||||
|
|
||||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
try:
|
try:
|
||||||
length, headers = await self.client.download_media(
|
length, headers = await self.client.download_media(
|
||||||
server_name,
|
server_name,
|
||||||
|
@ -693,8 +693,6 @@ class MediaRepository:
|
||||||
)
|
)
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
await finish()
|
|
||||||
|
|
||||||
if b"Content-Type" in headers:
|
if b"Content-Type" in headers:
|
||||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||||
else:
|
else:
|
||||||
|
@ -1045,17 +1043,17 @@ class MediaRepository:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
f,
|
|
||||||
fname,
|
|
||||||
finish,
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await self.media_storage.write_to_file(t_byte_source, f)
|
await self.media_storage.write_to_file(t_byte_source, f)
|
||||||
await finish()
|
|
||||||
finally:
|
finally:
|
||||||
t_byte_source.close()
|
t_byte_source.close()
|
||||||
|
|
||||||
|
# We flush and close the file to ensure that the bytes have
|
||||||
|
# been written before getting the size.
|
||||||
|
f.flush()
|
||||||
|
f.close()
|
||||||
|
|
||||||
t_len = os.path.getsize(fname)
|
t_len = os.path.getsize(fname)
|
||||||
|
|
||||||
# Write to database
|
# Write to database
|
||||||
|
|
|
@ -27,10 +27,9 @@ from typing import (
|
||||||
IO,
|
IO,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
AsyncIterator,
|
||||||
BinaryIO,
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -97,11 +96,9 @@ class MediaStorage:
|
||||||
the file path written to in the primary media store
|
the file path written to in the primary media store
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
async with self.store_into_file(file_info) as (f, fname):
|
||||||
# Write to the main media repository
|
# Write to the main media repository
|
||||||
await self.write_to_file(source, f)
|
await self.write_to_file(source, f)
|
||||||
# Write to the other storage providers
|
|
||||||
await finish_cb()
|
|
||||||
|
|
||||||
return fname
|
return fname
|
||||||
|
|
||||||
|
@ -111,32 +108,27 @@ class MediaStorage:
|
||||||
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
||||||
|
|
||||||
@trace_with_opname("MediaStorage.store_into_file")
|
@trace_with_opname("MediaStorage.store_into_file")
|
||||||
@contextlib.contextmanager
|
@contextlib.asynccontextmanager
|
||||||
def store_into_file(
|
async def store_into_file(
|
||||||
self, file_info: FileInfo
|
self, file_info: FileInfo
|
||||||
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
|
) -> AsyncIterator[Tuple[BinaryIO, str]]:
|
||||||
"""Context manager used to get a file like object to write into, as
|
"""Async Context manager used to get a file like object to write into, as
|
||||||
described by file_info.
|
described by file_info.
|
||||||
|
|
||||||
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
Actually yields a 2-tuple (file, fname,), where file is a file
|
||||||
like object that can be written to, fname is the absolute path of file
|
like object that can be written to and fname is the absolute path of file
|
||||||
on disk, and finish_cb is a function that returns an awaitable.
|
on disk.
|
||||||
|
|
||||||
fname can be used to read the contents from after upload, e.g. to
|
fname can be used to read the contents from after upload, e.g. to
|
||||||
generate thumbnails.
|
generate thumbnails.
|
||||||
|
|
||||||
finish_cb must be called and waited on after the file has been successfully been
|
|
||||||
written to. Should not be called if there was an error. Checks for spam and
|
|
||||||
stores the file into the configured storage providers.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_info: Info about the file to store
|
file_info: Info about the file to store
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
async with media_storage.store_into_file(info) as (f, fname,):
|
||||||
# .. write into f ...
|
# .. write into f ...
|
||||||
await finish_cb()
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
path = self._file_info_to_path(file_info)
|
path = self._file_info_to_path(file_info)
|
||||||
|
@ -145,63 +137,38 @@ class MediaStorage:
|
||||||
dirname = os.path.dirname(fname)
|
dirname = os.path.dirname(fname)
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
|
||||||
finished_called = [False]
|
|
||||||
|
|
||||||
main_media_repo_write_trace_scope = start_active_span(
|
|
||||||
"writing to main media repo"
|
|
||||||
)
|
|
||||||
main_media_repo_write_trace_scope.__enter__()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(fname, "wb") as f:
|
with start_active_span("writing to main media repo"):
|
||||||
|
with open(fname, "wb") as f:
|
||||||
|
yield f, fname
|
||||||
|
|
||||||
async def finish() -> None:
|
with start_active_span("writing to other storage providers"):
|
||||||
# When someone calls finish, we assume they are done writing to the main media repo
|
spam_check = (
|
||||||
main_media_repo_write_trace_scope.__exit__(None, None, None)
|
await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
||||||
|
ReadableFileWrapper(self.clock, fname), file_info
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
||||||
|
logger.info("Blocking media due to spam checker")
|
||||||
|
# Note that we'll delete the stored media, due to the
|
||||||
|
# try/except below. The media also won't be stored in
|
||||||
|
# the DB.
|
||||||
|
# We currently ignore any additional field returned by
|
||||||
|
# the spam-check API.
|
||||||
|
raise SpamMediaException(errcode=spam_check[0])
|
||||||
|
|
||||||
with start_active_span("writing to other storage providers"):
|
for provider in self.storage_providers:
|
||||||
# Ensure that all writes have been flushed and close the
|
with start_active_span(str(provider)):
|
||||||
# file.
|
await provider.store_file(path, file_info)
|
||||||
f.flush()
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
|
||||||
ReadableFileWrapper(self.clock, fname), file_info
|
|
||||||
)
|
|
||||||
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
|
||||||
logger.info("Blocking media due to spam checker")
|
|
||||||
# Note that we'll delete the stored media, due to the
|
|
||||||
# try/except below. The media also won't be stored in
|
|
||||||
# the DB.
|
|
||||||
# We currently ignore any additional field returned by
|
|
||||||
# the spam-check API.
|
|
||||||
raise SpamMediaException(errcode=spam_check[0])
|
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
|
||||||
with start_active_span(str(provider)):
|
|
||||||
await provider.store_file(path, file_info)
|
|
||||||
|
|
||||||
finished_called[0] = True
|
|
||||||
|
|
||||||
yield f, fname, finish
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
main_media_repo_write_trace_scope.__exit__(
|
|
||||||
type(e), None, e.__traceback__
|
|
||||||
)
|
|
||||||
os.remove(fname)
|
os.remove(fname)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise e from None
|
raise e from None
|
||||||
|
|
||||||
if not finished_called:
|
|
||||||
exc = Exception("Finished callback not called")
|
|
||||||
main_media_repo_write_trace_scope.__exit__(
|
|
||||||
type(exc), None, exc.__traceback__
|
|
||||||
)
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
||||||
"""Attempts to fetch media described by file_info from the local cache
|
"""Attempts to fetch media described by file_info from the local cache
|
||||||
and configured storage providers.
|
and configured storage providers.
|
||||||
|
|
|
@ -22,11 +22,27 @@
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Optional, Tuple, Type
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||||
|
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
|
||||||
|
from synapse.http.server import respond_with_json
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
|
from synapse.media._base import (
|
||||||
|
FileInfo,
|
||||||
|
ThumbnailInfo,
|
||||||
|
respond_404,
|
||||||
|
respond_with_file,
|
||||||
|
respond_with_responder,
|
||||||
|
)
|
||||||
|
from synapse.media.media_storage import MediaStorage
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.media.media_repository import MediaRepository
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -231,3 +247,471 @@ class Thumbnailer:
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
# Make sure we actually do close the image, rather than leak data.
|
# Make sure we actually do close the image, rather than leak data.
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ThumbnailProvider:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
media_repo: "MediaRepository",
|
||||||
|
media_storage: MediaStorage,
|
||||||
|
):
|
||||||
|
self.hs = hs
|
||||||
|
self.media_repo = media_repo
|
||||||
|
self.media_storage = media_storage
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||||
|
|
||||||
|
async def respond_local_thumbnail(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
media_id: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
method: str,
|
||||||
|
m_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
) -> None:
|
||||||
|
media_info = await self.media_repo.get_local_media_info(
|
||||||
|
request, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
if not media_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||||
|
await self._select_and_respond_with_thumbnail(
|
||||||
|
request,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
method,
|
||||||
|
m_type,
|
||||||
|
thumbnail_infos,
|
||||||
|
media_id,
|
||||||
|
media_id,
|
||||||
|
url_cache=bool(media_info.url_cache),
|
||||||
|
server_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def select_or_generate_local_thumbnail(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
media_id: str,
|
||||||
|
desired_width: int,
|
||||||
|
desired_height: int,
|
||||||
|
desired_method: str,
|
||||||
|
desired_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
) -> None:
|
||||||
|
media_info = await self.media_repo.get_local_media_info(
|
||||||
|
request, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
if not media_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
t_w = info.width == desired_width
|
||||||
|
t_h = info.height == desired_height
|
||||||
|
t_method = info.method == desired_method
|
||||||
|
t_type = info.type == desired_type
|
||||||
|
|
||||||
|
if t_w and t_h and t_method and t_type:
|
||||||
|
file_info = FileInfo(
|
||||||
|
server_name=None,
|
||||||
|
file_id=media_id,
|
||||||
|
url_cache=bool(media_info.url_cache),
|
||||||
|
thumbnail=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
await respond_with_responder(
|
||||||
|
request, responder, info.type, info.length
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||||
|
|
||||||
|
# Okay, so we generate one.
|
||||||
|
file_path = await self.media_repo.generate_local_exact_thumbnail(
|
||||||
|
media_id,
|
||||||
|
desired_width,
|
||||||
|
desired_height,
|
||||||
|
desired_method,
|
||||||
|
desired_type,
|
||||||
|
url_cache=bool(media_info.url_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
await respond_with_file(request, desired_type, file_path)
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to generate thumbnail")
|
||||||
|
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||||
|
|
||||||
|
async def select_or_generate_remote_thumbnail(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
server_name: str,
|
||||||
|
media_id: str,
|
||||||
|
desired_width: int,
|
||||||
|
desired_height: int,
|
||||||
|
desired_method: str,
|
||||||
|
desired_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
) -> None:
|
||||||
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
|
server_name, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
if not media_info:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||||
|
server_name, media_id
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id = media_info.filesystem_id
|
||||||
|
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
t_w = info.width == desired_width
|
||||||
|
t_h = info.height == desired_height
|
||||||
|
t_method = info.method == desired_method
|
||||||
|
t_type = info.type == desired_type
|
||||||
|
|
||||||
|
if t_w and t_h and t_method and t_type:
|
||||||
|
file_info = FileInfo(
|
||||||
|
server_name=server_name,
|
||||||
|
file_id=file_id,
|
||||||
|
thumbnail=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
await respond_with_responder(
|
||||||
|
request, responder, info.type, info.length
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||||
|
|
||||||
|
# Okay, so we generate one.
|
||||||
|
file_path = await self.media_repo.generate_remote_exact_thumbnail(
|
||||||
|
server_name,
|
||||||
|
file_id,
|
||||||
|
media_id,
|
||||||
|
desired_width,
|
||||||
|
desired_height,
|
||||||
|
desired_method,
|
||||||
|
desired_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
await respond_with_file(request, desired_type, file_path)
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to generate thumbnail")
|
||||||
|
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||||
|
|
||||||
|
async def respond_remote_thumbnail(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
server_name: str,
|
||||||
|
media_id: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
method: str,
|
||||||
|
m_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
) -> None:
|
||||||
|
# TODO: Don't download the whole remote file
|
||||||
|
# We should proxy the thumbnail from the remote server instead of
|
||||||
|
# downloading the remote file and generating our own thumbnails.
|
||||||
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
|
server_name, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
if not media_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||||
|
server_name, media_id
|
||||||
|
)
|
||||||
|
await self._select_and_respond_with_thumbnail(
|
||||||
|
request,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
method,
|
||||||
|
m_type,
|
||||||
|
thumbnail_infos,
|
||||||
|
media_id,
|
||||||
|
media_info.filesystem_id,
|
||||||
|
url_cache=False,
|
||||||
|
server_name=server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _select_and_respond_with_thumbnail(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
desired_width: int,
|
||||||
|
desired_height: int,
|
||||||
|
desired_method: str,
|
||||||
|
desired_type: str,
|
||||||
|
thumbnail_infos: List[ThumbnailInfo],
|
||||||
|
media_id: str,
|
||||||
|
file_id: str,
|
||||||
|
url_cache: bool,
|
||||||
|
server_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The incoming request.
|
||||||
|
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||||
|
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||||
|
desired_method: The desired method used to generate the thumbnail.
|
||||||
|
desired_type: The desired content-type of the thumbnail.
|
||||||
|
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
|
||||||
|
file_id: The ID of the media that a thumbnail is being requested for.
|
||||||
|
url_cache: True if this is from a URL cache.
|
||||||
|
server_name: The server name, if this is a remote thumbnail.
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
|
||||||
|
media_id,
|
||||||
|
desired_width,
|
||||||
|
desired_height,
|
||||||
|
desired_method,
|
||||||
|
thumbnail_infos,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
|
||||||
|
# different code path to handle it.
|
||||||
|
assert not self.dynamic_thumbnails
|
||||||
|
|
||||||
|
if thumbnail_infos:
|
||||||
|
file_info = self._select_thumbnail(
|
||||||
|
desired_width,
|
||||||
|
desired_height,
|
||||||
|
desired_method,
|
||||||
|
desired_type,
|
||||||
|
thumbnail_infos,
|
||||||
|
file_id,
|
||||||
|
url_cache,
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
if not file_info:
|
||||||
|
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
# The thumbnail property must exist.
|
||||||
|
assert file_info.thumbnail is not None
|
||||||
|
|
||||||
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
await respond_with_responder(
|
||||||
|
request,
|
||||||
|
responder,
|
||||||
|
file_info.thumbnail.type,
|
||||||
|
file_info.thumbnail.length,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If we can't find the thumbnail we regenerate it. This can happen
|
||||||
|
# if e.g. we've deleted the thumbnails but still have the original
|
||||||
|
# image somewhere.
|
||||||
|
#
|
||||||
|
# Since we have an entry for the thumbnail in the DB we a) know we
|
||||||
|
# have have successfully generated the thumbnail in the past (so we
|
||||||
|
# don't need to worry about repeatedly failing to generate
|
||||||
|
# thumbnails), and b) have already calculated that appropriate
|
||||||
|
# width/height/method so we can just call the "generate exact"
|
||||||
|
# methods.
|
||||||
|
|
||||||
|
# First let's check that we do actually have the original image
|
||||||
|
# still. This will throw a 404 if we don't.
|
||||||
|
# TODO: We should refetch the thumbnails for remote media.
|
||||||
|
await self.media_storage.ensure_media_is_in_local_cache(
|
||||||
|
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||||
|
)
|
||||||
|
|
||||||
|
if server_name:
|
||||||
|
await self.media_repo.generate_remote_exact_thumbnail(
|
||||||
|
server_name,
|
||||||
|
file_id=file_id,
|
||||||
|
media_id=media_id,
|
||||||
|
t_width=file_info.thumbnail.width,
|
||||||
|
t_height=file_info.thumbnail.height,
|
||||||
|
t_method=file_info.thumbnail.method,
|
||||||
|
t_type=file_info.thumbnail.type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.media_repo.generate_local_exact_thumbnail(
|
||||||
|
media_id=media_id,
|
||||||
|
t_width=file_info.thumbnail.width,
|
||||||
|
t_height=file_info.thumbnail.height,
|
||||||
|
t_method=file_info.thumbnail.method,
|
||||||
|
t_type=file_info.thumbnail.type,
|
||||||
|
url_cache=url_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
|
await respond_with_responder(
|
||||||
|
request,
|
||||||
|
responder,
|
||||||
|
file_info.thumbnail.type,
|
||||||
|
file_info.thumbnail.length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This might be because:
|
||||||
|
# 1. We can't create thumbnails for the given media (corrupted or
|
||||||
|
# unsupported file type), or
|
||||||
|
# 2. The thumbnailing process never ran or errored out initially
|
||||||
|
# when the media was first uploaded (these bugs should be
|
||||||
|
# reported and fixed).
|
||||||
|
# Note that we don't attempt to generate a thumbnail now because
|
||||||
|
# `dynamic_thumbnails` is disabled.
|
||||||
|
logger.info("Failed to find any generated thumbnails")
|
||||||
|
|
||||||
|
assert request.path is not None
|
||||||
|
respond_with_json(
|
||||||
|
request,
|
||||||
|
400,
|
||||||
|
cs_error(
|
||||||
|
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
|
||||||
|
% (
|
||||||
|
request.path.decode(),
|
||||||
|
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
|
||||||
|
),
|
||||||
|
code=Codes.UNKNOWN,
|
||||||
|
),
|
||||||
|
send_cors=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _select_thumbnail(
|
||||||
|
self,
|
||||||
|
desired_width: int,
|
||||||
|
desired_height: int,
|
||||||
|
desired_method: str,
|
||||||
|
desired_type: str,
|
||||||
|
thumbnail_infos: List[ThumbnailInfo],
|
||||||
|
file_id: str,
|
||||||
|
url_cache: bool,
|
||||||
|
server_name: Optional[str],
|
||||||
|
) -> Optional[FileInfo]:
|
||||||
|
"""
|
||||||
|
Choose an appropriate thumbnail from the previously generated thumbnails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||||
|
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||||
|
desired_method: The desired method used to generate the thumbnail.
|
||||||
|
desired_type: The desired content-type of the thumbnail.
|
||||||
|
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
|
||||||
|
file_id: The ID of the media that a thumbnail is being requested for.
|
||||||
|
url_cache: True if this is from a URL cache.
|
||||||
|
server_name: The server name, if this is a remote thumbnail.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The thumbnail which best matches the desired parameters.
|
||||||
|
"""
|
||||||
|
desired_method = desired_method.lower()
|
||||||
|
|
||||||
|
# The chosen thumbnail.
|
||||||
|
thumbnail_info = None
|
||||||
|
|
||||||
|
d_w = desired_width
|
||||||
|
d_h = desired_height
|
||||||
|
|
||||||
|
if desired_method == "crop":
|
||||||
|
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||||
|
crop_info_list: List[
|
||||||
|
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||||
|
] = []
|
||||||
|
# Other thumbnails.
|
||||||
|
crop_info_list2: List[
|
||||||
|
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
||||||
|
] = []
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
# Skip thumbnails generated with different methods.
|
||||||
|
if info.method != "crop":
|
||||||
|
continue
|
||||||
|
|
||||||
|
t_w = info.width
|
||||||
|
t_h = info.height
|
||||||
|
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||||
|
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||||
|
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||||
|
type_quality = desired_type != info.type
|
||||||
|
length_quality = info.length
|
||||||
|
if t_w >= d_w or t_h >= d_h:
|
||||||
|
crop_info_list.append(
|
||||||
|
(
|
||||||
|
aspect_quality,
|
||||||
|
min_quality,
|
||||||
|
size_quality,
|
||||||
|
type_quality,
|
||||||
|
length_quality,
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
crop_info_list2.append(
|
||||||
|
(
|
||||||
|
aspect_quality,
|
||||||
|
min_quality,
|
||||||
|
size_quality,
|
||||||
|
type_quality,
|
||||||
|
length_quality,
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||||
|
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||||
|
# the thumbnail info and pick the thumbnail that appears earlier
|
||||||
|
# in the list of candidates.
|
||||||
|
if crop_info_list:
|
||||||
|
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
|
||||||
|
elif crop_info_list2:
|
||||||
|
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
|
||||||
|
elif desired_method == "scale":
|
||||||
|
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||||
|
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||||
|
# Other thumbnails.
|
||||||
|
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
||||||
|
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
# Skip thumbnails generated with different methods.
|
||||||
|
if info.method != "scale":
|
||||||
|
continue
|
||||||
|
|
||||||
|
t_w = info.width
|
||||||
|
t_h = info.height
|
||||||
|
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||||
|
type_quality = desired_type != info.type
|
||||||
|
length_quality = info.length
|
||||||
|
if t_w >= d_w or t_h >= d_h:
|
||||||
|
info_list.append((size_quality, type_quality, length_quality, info))
|
||||||
|
else:
|
||||||
|
info_list2.append(
|
||||||
|
(size_quality, type_quality, length_quality, info)
|
||||||
|
)
|
||||||
|
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||||
|
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||||
|
# the thumbnail info and pick the thumbnail that appears earlier
|
||||||
|
# in the list of candidates.
|
||||||
|
if info_list:
|
||||||
|
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
|
||||||
|
elif info_list2:
|
||||||
|
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
|
||||||
|
|
||||||
|
if thumbnail_info:
|
||||||
|
return FileInfo(
|
||||||
|
file_id=file_id,
|
||||||
|
url_cache=url_cache,
|
||||||
|
server_name=server_name,
|
||||||
|
thumbnail=thumbnail_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No matching thumbnail was found.
|
||||||
|
return None
|
||||||
|
|
|
@ -592,7 +592,7 @@ class UrlPreviewer:
|
||||||
|
|
||||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
if url.startswith("data:"):
|
if url.startswith("data:"):
|
||||||
if not allow_data_urls:
|
if not allow_data_urls:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -603,8 +603,6 @@ class UrlPreviewer:
|
||||||
else:
|
else:
|
||||||
download_result = await self._download_url(url, f)
|
download_result = await self._download_url(url, f)
|
||||||
|
|
||||||
await finish()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -763,6 +763,29 @@ class Notifier:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
|
||||||
|
"""Wait for this worker to catch up with the given stream token."""
|
||||||
|
|
||||||
|
start = self.clock.time_msec()
|
||||||
|
while True:
|
||||||
|
current_token = self.event_sources.get_current_token()
|
||||||
|
if stream_token.is_before_or_eq(current_token):
|
||||||
|
return True
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
if now - start > 10_000:
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Waiting for current token to reach %s; currently at %s",
|
||||||
|
stream_token,
|
||||||
|
current_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: be better
|
||||||
|
await self.clock.sleep(0.5)
|
||||||
|
|
||||||
async def _get_room_ids(
|
async def _get_room_ids(
|
||||||
self, user: UserID, explicit_room_id: Optional[str]
|
self, user: UserID, explicit_room_id: Optional[str]
|
||||||
) -> Tuple[StrCollection, bool]:
|
) -> Tuple[StrCollection, bool]:
|
||||||
|
|
205
synapse/rest/client/media.py
Normal file
205
synapse/rest/client/media.py
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
# Originally licensed under the Apache License, Version 2.0:
|
||||||
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||||
|
#
|
||||||
|
# [This file includes modifications made by New Vector Limited]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from synapse.http.server import (
|
||||||
|
HttpServer,
|
||||||
|
respond_with_json,
|
||||||
|
respond_with_json_bytes,
|
||||||
|
set_corp_headers,
|
||||||
|
set_cors_headers,
|
||||||
|
)
|
||||||
|
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.media._base import (
|
||||||
|
DEFAULT_MAX_TIMEOUT_MS,
|
||||||
|
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||||
|
respond_404,
|
||||||
|
)
|
||||||
|
from synapse.media.media_repository import MediaRepository
|
||||||
|
from synapse.media.media_storage import MediaStorage
|
||||||
|
from synapse.media.thumbnailer import ThumbnailProvider
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UnstablePreviewURLServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
|
||||||
|
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
|
||||||
|
specific additions).
|
||||||
|
|
||||||
|
This does have trade-offs compared to other designs:
|
||||||
|
|
||||||
|
* Pros:
|
||||||
|
* Simple and flexible; can be used by any clients at any point
|
||||||
|
* Cons:
|
||||||
|
* If each homeserver provides one of these independently, all the homeservers in a
|
||||||
|
room may needlessly DoS the target URI
|
||||||
|
* The URL metadata must be stored somewhere, rather than just using Matrix
|
||||||
|
itself to store the media.
|
||||||
|
* Matrix cannot be used to distribute the metadata between homeservers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = [
|
||||||
|
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
media_repo: "MediaRepository",
|
||||||
|
media_storage: MediaStorage,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.media_repo = media_repo
|
||||||
|
self.media_storage = media_storage
|
||||||
|
assert self.media_repo.url_previewer is not None
|
||||||
|
self.url_previewer = self.media_repo.url_previewer
|
||||||
|
|
||||||
|
async def on_GET(self, request: SynapseRequest) -> None:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
url = parse_string(request, "url", required=True)
|
||||||
|
ts = parse_integer(request, "ts")
|
||||||
|
if ts is None:
|
||||||
|
ts = self.clock.time_msec()
|
||||||
|
|
||||||
|
og = await self.url_previewer.preview(url, requester.user, ts)
|
||||||
|
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UnstableMediaConfigResource(RestServlet):
|
||||||
|
PATTERNS = [
|
||||||
|
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
config = hs.config
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
|
||||||
|
|
||||||
|
async def on_GET(self, request: SynapseRequest) -> None:
|
||||||
|
await self.auth.get_user_by_req(request)
|
||||||
|
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UnstableThumbnailResource(RestServlet):
|
||||||
|
PATTERNS = [
|
||||||
|
re.compile(
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
media_repo: "MediaRepository",
|
||||||
|
media_storage: MediaStorage,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.media_repo = media_repo
|
||||||
|
self.media_storage = media_storage
|
||||||
|
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||||
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||||
|
self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
|
) -> None:
|
||||||
|
# Validate the server name, raising if invalid
|
||||||
|
parse_and_validate_server_name(server_name)
|
||||||
|
await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
set_cors_headers(request)
|
||||||
|
set_corp_headers(request)
|
||||||
|
width = parse_integer(request, "width", required=True)
|
||||||
|
height = parse_integer(request, "height", required=True)
|
||||||
|
method = parse_string(request, "method", "scale")
|
||||||
|
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
||||||
|
m_type = "image/png"
|
||||||
|
max_timeout_ms = parse_integer(
|
||||||
|
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||||
|
)
|
||||||
|
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||||
|
|
||||||
|
if self._is_mine_server_name(server_name):
|
||||||
|
if self.dynamic_thumbnails:
|
||||||
|
await self.thumbnailer.select_or_generate_local_thumbnail(
|
||||||
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.thumbnailer.respond_local_thumbnail(
|
||||||
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
|
)
|
||||||
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
|
else:
|
||||||
|
# Don't let users download media from configured domains, even if it
|
||||||
|
# is already downloaded. This is Trust & Safety tooling to make some
|
||||||
|
# media inaccessible to local users.
|
||||||
|
# See `prevent_media_downloads_from` config docs for more info.
|
||||||
|
if server_name in self.prevent_media_downloads_from:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
remote_resp_function = (
|
||||||
|
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||||
|
if self.dynamic_thumbnails
|
||||||
|
else self.thumbnailer.respond_remote_thumbnail
|
||||||
|
)
|
||||||
|
await remote_resp_function(
|
||||||
|
request,
|
||||||
|
server_name,
|
||||||
|
media_id,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
method,
|
||||||
|
m_type,
|
||||||
|
max_timeout_ms,
|
||||||
|
)
|
||||||
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
|
if hs.config.experimental.msc3916_authenticated_media_enabled:
|
||||||
|
media_repo = hs.get_media_repository()
|
||||||
|
if hs.config.media.url_preview_enabled:
|
||||||
|
UnstablePreviewURLServlet(
|
||||||
|
hs, media_repo, media_repo.media_storage
|
||||||
|
).register(http_server)
|
||||||
|
UnstableMediaConfigResource(hs).register(http_server)
|
||||||
|
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
||||||
|
http_server
|
||||||
|
)
|
|
@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SlidingSyncE2eeRestServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
|
||||||
|
of Sliding Sync but doesn't have any sliding window component. It's just a way to
|
||||||
|
get E2EE events without having to sit through a big initial sync (`/sync` v2). And
|
||||||
|
we can avoid encryption events being backed up by the main sync response.
|
||||||
|
|
||||||
|
Having To-Device messages split out to this sync endpoint also helps when clients
|
||||||
|
need to have 2 or more sync streams open at a time, e.g a push notification process
|
||||||
|
and a main process. This can cause the two processes to race to fetch the To-Device
|
||||||
|
events, resulting in the need for complex synchronisation rules to ensure the token
|
||||||
|
is correctly and atomically exchanged between processes.
|
||||||
|
|
||||||
|
GET parameters::
|
||||||
|
timeout(int): How long to wait for new events in milliseconds.
|
||||||
|
since(batch_token): Batch token when asking for incremental deltas.
|
||||||
|
|
||||||
|
Response JSON::
|
||||||
|
{
|
||||||
|
"next_batch": // batch token for the next /sync
|
||||||
|
"to_device": {
|
||||||
|
// list of to-device events
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
|
||||||
|
"type": "m.room.encrypted",
|
||||||
|
"sender": "@alice:example.com",
|
||||||
|
}
|
||||||
|
// ...
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"device_lists": {
|
||||||
|
"changed": ["@alice:example.com"],
|
||||||
|
"left": ["@bob:example.com"]
|
||||||
|
},
|
||||||
|
"device_one_time_keys_count": {
|
||||||
|
"signed_curve25519": 50
|
||||||
|
},
|
||||||
|
"device_unused_fallback_key_types": [
|
||||||
|
"signed_curve25519"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
"/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.sync_handler = hs.get_sync_handler()
|
||||||
|
|
||||||
|
# Filtering only matters for the `device_lists` because it requires a bunch of
|
||||||
|
# derived information from rooms (see how `_generate_sync_entry_for_rooms()`
|
||||||
|
# prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
|
||||||
|
self.only_member_events_filter_collection = FilterCollection(
|
||||||
|
self.hs,
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
# We only care about membership events for the `device_lists`.
|
||||||
|
# Membership will tell us whether a user has joined/left a room and
|
||||||
|
# if there are new devices to encrypt for.
|
||||||
|
"timeline": {
|
||||||
|
"types": ["m.room.member"],
|
||||||
|
},
|
||||||
|
"state": {
|
||||||
|
"types": ["m.room.member"],
|
||||||
|
},
|
||||||
|
# We don't want any extra account_data generated because it's not
|
||||||
|
# returned by this endpoint. This helps us avoid work in
|
||||||
|
# `_generate_sync_entry_for_rooms()`
|
||||||
|
"account_data": {
|
||||||
|
"not_types": ["*"],
|
||||||
|
},
|
||||||
|
# We don't want any extra ephemeral data generated because it's not
|
||||||
|
# returned by this endpoint. This helps us avoid work in
|
||||||
|
# `_generate_sync_entry_for_rooms()`
|
||||||
|
"ephemeral": {
|
||||||
|
"not_types": ["*"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# We don't want any extra account_data generated because it's not
|
||||||
|
# returned by this endpoint. (This is just here for good measure)
|
||||||
|
"account_data": {
|
||||||
|
"not_types": ["*"],
|
||||||
|
},
|
||||||
|
# We don't want any extra presence data generated because it's not
|
||||||
|
# returned by this endpoint. (This is just here for good measure)
|
||||||
|
"presence": {
|
||||||
|
"not_types": ["*"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
user = requester.user
|
||||||
|
device_id = requester.device_id
|
||||||
|
|
||||||
|
timeout = parse_integer(request, "timeout", default=0)
|
||||||
|
since = parse_string(request, "since")
|
||||||
|
|
||||||
|
sync_config = SyncConfig(
|
||||||
|
user=user,
|
||||||
|
filter_collection=self.only_member_events_filter_collection,
|
||||||
|
is_guest=requester.is_guest,
|
||||||
|
device_id=device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
since_token = None
|
||||||
|
if since is not None:
|
||||||
|
since_token = await StreamToken.from_string(self.store, since)
|
||||||
|
|
||||||
|
# Request cache key
|
||||||
|
request_key = (
|
||||||
|
SyncVersion.E2EE_SYNC,
|
||||||
|
user,
|
||||||
|
timeout,
|
||||||
|
since,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gather data for the response
|
||||||
|
sync_result = await self.sync_handler.wait_for_sync_for_user(
|
||||||
|
requester,
|
||||||
|
sync_config,
|
||||||
|
SyncVersion.E2EE_SYNC,
|
||||||
|
request_key,
|
||||||
|
since_token=since_token,
|
||||||
|
timeout=timeout,
|
||||||
|
full_state=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The client may have disconnected by now; don't bother to serialize the
|
||||||
|
# response if so.
|
||||||
|
if request._disconnected:
|
||||||
|
logger.info("Client has disconnected; not serializing response.")
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
response: JsonDict = defaultdict(dict)
|
||||||
|
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
|
||||||
|
|
||||||
|
if sync_result.to_device:
|
||||||
|
response["to_device"] = {"events": sync_result.to_device}
|
||||||
|
|
||||||
|
if sync_result.device_lists.changed:
|
||||||
|
response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
|
||||||
|
if sync_result.device_lists.left:
|
||||||
|
response["device_lists"]["left"] = list(sync_result.device_lists.left)
|
||||||
|
|
||||||
|
# We always include this because https://github.com/vector-im/element-android/issues/3725
|
||||||
|
# The spec isn't terribly clear on when this can be omitted and how a client would tell
|
||||||
|
# the difference between "no keys present" and "nothing changed" in terms of whole field
|
||||||
|
# absent / individual key type entry absent
|
||||||
|
# Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
|
||||||
|
response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
|
||||||
|
|
||||||
|
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
|
||||||
|
# states that this field should always be included, as long as the server supports the feature.
|
||||||
|
response["device_unused_fallback_key_types"] = (
|
||||||
|
sync_result.device_unused_fallback_key_types
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, response
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
SyncRestServlet(hs).register(http_server)
|
SyncRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
if hs.config.experimental.msc3575_enabled:
|
||||||
|
SlidingSyncE2eeRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -22,23 +22,18 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
from synapse.http.server import set_corp_headers, set_cors_headers
|
||||||
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
|
|
||||||
from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
|
|
||||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.media._base import (
|
from synapse.media._base import (
|
||||||
DEFAULT_MAX_TIMEOUT_MS,
|
DEFAULT_MAX_TIMEOUT_MS,
|
||||||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||||
FileInfo,
|
|
||||||
ThumbnailInfo,
|
|
||||||
respond_404,
|
respond_404,
|
||||||
respond_with_file,
|
|
||||||
respond_with_responder,
|
|
||||||
)
|
)
|
||||||
from synapse.media.media_storage import MediaStorage
|
from synapse.media.media_storage import MediaStorage
|
||||||
|
from synapse.media.thumbnailer import ThumbnailProvider
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.media_storage = media_storage
|
self.media_storage = media_storage
|
||||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
|
||||||
self._is_mine_server_name = hs.is_mine_server_name
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||||
|
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||||
|
self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: SynapseRequest, server_name: str, media_id: str
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
|
@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
|
||||||
|
|
||||||
if self._is_mine_server_name(server_name):
|
if self._is_mine_server_name(server_name):
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
await self._select_or_generate_local_thumbnail(
|
await self.thumbnail_provider.select_or_generate_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._respond_local_thumbnail(
|
await self.thumbnail_provider.respond_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type, max_timeout_ms
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(None, media_id)
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
|
@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
|
||||||
return
|
return
|
||||||
|
|
||||||
remote_resp_function = (
|
remote_resp_function = (
|
||||||
self._select_or_generate_remote_thumbnail
|
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||||
if self.dynamic_thumbnails
|
if self.dynamic_thumbnails
|
||||||
else self._respond_remote_thumbnail
|
else self.thumbnail_provider.respond_remote_thumbnail
|
||||||
)
|
)
|
||||||
await remote_resp_function(
|
await remote_resp_function(
|
||||||
request,
|
request,
|
||||||
|
@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
|
||||||
max_timeout_ms,
|
max_timeout_ms,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
async def _respond_local_thumbnail(
|
|
||||||
self,
|
|
||||||
request: SynapseRequest,
|
|
||||||
media_id: str,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
method: str,
|
|
||||||
m_type: str,
|
|
||||||
max_timeout_ms: int,
|
|
||||||
) -> None:
|
|
||||||
media_info = await self.media_repo.get_local_media_info(
|
|
||||||
request, media_id, max_timeout_ms
|
|
||||||
)
|
|
||||||
if not media_info:
|
|
||||||
return
|
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
|
||||||
await self._select_and_respond_with_thumbnail(
|
|
||||||
request,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
method,
|
|
||||||
m_type,
|
|
||||||
thumbnail_infos,
|
|
||||||
media_id,
|
|
||||||
media_id,
|
|
||||||
url_cache=bool(media_info.url_cache),
|
|
||||||
server_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _select_or_generate_local_thumbnail(
|
|
||||||
self,
|
|
||||||
request: SynapseRequest,
|
|
||||||
media_id: str,
|
|
||||||
desired_width: int,
|
|
||||||
desired_height: int,
|
|
||||||
desired_method: str,
|
|
||||||
desired_type: str,
|
|
||||||
max_timeout_ms: int,
|
|
||||||
) -> None:
|
|
||||||
media_info = await self.media_repo.get_local_media_info(
|
|
||||||
request, media_id, max_timeout_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
if not media_info:
|
|
||||||
return
|
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
|
||||||
for info in thumbnail_infos:
|
|
||||||
t_w = info.width == desired_width
|
|
||||||
t_h = info.height == desired_height
|
|
||||||
t_method = info.method == desired_method
|
|
||||||
t_type = info.type == desired_type
|
|
||||||
|
|
||||||
if t_w and t_h and t_method and t_type:
|
|
||||||
file_info = FileInfo(
|
|
||||||
server_name=None,
|
|
||||||
file_id=media_id,
|
|
||||||
url_cache=bool(media_info.url_cache),
|
|
||||||
thumbnail=info,
|
|
||||||
)
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
|
||||||
if responder:
|
|
||||||
await respond_with_responder(
|
|
||||||
request, responder, info.type, info.length
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
|
||||||
|
|
||||||
# Okay, so we generate one.
|
|
||||||
file_path = await self.media_repo.generate_local_exact_thumbnail(
|
|
||||||
media_id,
|
|
||||||
desired_width,
|
|
||||||
desired_height,
|
|
||||||
desired_method,
|
|
||||||
desired_type,
|
|
||||||
url_cache=bool(media_info.url_cache),
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path:
|
|
||||||
await respond_with_file(request, desired_type, file_path)
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to generate thumbnail")
|
|
||||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
|
||||||
|
|
||||||
async def _select_or_generate_remote_thumbnail(
|
|
||||||
self,
|
|
||||||
request: SynapseRequest,
|
|
||||||
server_name: str,
|
|
||||||
media_id: str,
|
|
||||||
desired_width: int,
|
|
||||||
desired_height: int,
|
|
||||||
desired_method: str,
|
|
||||||
desired_type: str,
|
|
||||||
max_timeout_ms: int,
|
|
||||||
) -> None:
|
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
|
||||||
server_name, media_id, max_timeout_ms
|
|
||||||
)
|
|
||||||
if not media_info:
|
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
|
||||||
server_name, media_id
|
|
||||||
)
|
|
||||||
|
|
||||||
file_id = media_info.filesystem_id
|
|
||||||
|
|
||||||
for info in thumbnail_infos:
|
|
||||||
t_w = info.width == desired_width
|
|
||||||
t_h = info.height == desired_height
|
|
||||||
t_method = info.method == desired_method
|
|
||||||
t_type = info.type == desired_type
|
|
||||||
|
|
||||||
if t_w and t_h and t_method and t_type:
|
|
||||||
file_info = FileInfo(
|
|
||||||
server_name=server_name,
|
|
||||||
file_id=file_id,
|
|
||||||
thumbnail=info,
|
|
||||||
)
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
|
||||||
if responder:
|
|
||||||
await respond_with_responder(
|
|
||||||
request, responder, info.type, info.length
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
|
||||||
|
|
||||||
# Okay, so we generate one.
|
|
||||||
file_path = await self.media_repo.generate_remote_exact_thumbnail(
|
|
||||||
server_name,
|
|
||||||
file_id,
|
|
||||||
media_id,
|
|
||||||
desired_width,
|
|
||||||
desired_height,
|
|
||||||
desired_method,
|
|
||||||
desired_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path:
|
|
||||||
await respond_with_file(request, desired_type, file_path)
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to generate thumbnail")
|
|
||||||
raise SynapseError(400, "Failed to generate thumbnail.")
|
|
||||||
|
|
||||||
async def _respond_remote_thumbnail(
|
|
||||||
self,
|
|
||||||
request: SynapseRequest,
|
|
||||||
server_name: str,
|
|
||||||
media_id: str,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
method: str,
|
|
||||||
m_type: str,
|
|
||||||
max_timeout_ms: int,
|
|
||||||
) -> None:
|
|
||||||
# TODO: Don't download the whole remote file
|
|
||||||
# We should proxy the thumbnail from the remote server instead of
|
|
||||||
# downloading the remote file and generating our own thumbnails.
|
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
|
||||||
server_name, media_id, max_timeout_ms
|
|
||||||
)
|
|
||||||
if not media_info:
|
|
||||||
return
|
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
|
||||||
server_name, media_id
|
|
||||||
)
|
|
||||||
await self._select_and_respond_with_thumbnail(
|
|
||||||
request,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
method,
|
|
||||||
m_type,
|
|
||||||
thumbnail_infos,
|
|
||||||
media_id,
|
|
||||||
media_info.filesystem_id,
|
|
||||||
url_cache=False,
|
|
||||||
server_name=server_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _select_and_respond_with_thumbnail(
|
|
||||||
self,
|
|
||||||
request: SynapseRequest,
|
|
||||||
desired_width: int,
|
|
||||||
desired_height: int,
|
|
||||||
desired_method: str,
|
|
||||||
desired_type: str,
|
|
||||||
thumbnail_infos: List[ThumbnailInfo],
|
|
||||||
media_id: str,
|
|
||||||
file_id: str,
|
|
||||||
url_cache: bool,
|
|
||||||
server_name: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The incoming request.
|
|
||||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
|
||||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
|
||||||
desired_method: The desired method used to generate the thumbnail.
|
|
||||||
desired_type: The desired content-type of the thumbnail.
|
|
||||||
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
|
|
||||||
file_id: The ID of the media that a thumbnail is being requested for.
|
|
||||||
url_cache: True if this is from a URL cache.
|
|
||||||
server_name: The server name, if this is a remote thumbnail.
|
|
||||||
"""
|
|
||||||
logger.debug(
|
|
||||||
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
|
|
||||||
media_id,
|
|
||||||
desired_width,
|
|
||||||
desired_height,
|
|
||||||
desired_method,
|
|
||||||
thumbnail_infos,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
|
|
||||||
# different code path to handle it.
|
|
||||||
assert not self.dynamic_thumbnails
|
|
||||||
|
|
||||||
if thumbnail_infos:
|
|
||||||
file_info = self._select_thumbnail(
|
|
||||||
desired_width,
|
|
||||||
desired_height,
|
|
||||||
desired_method,
|
|
||||||
desired_type,
|
|
||||||
thumbnail_infos,
|
|
||||||
file_id,
|
|
||||||
url_cache,
|
|
||||||
server_name,
|
|
||||||
)
|
|
||||||
if not file_info:
|
|
||||||
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
|
|
||||||
# The thumbnail property must exist.
|
|
||||||
assert file_info.thumbnail is not None
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
|
||||||
if responder:
|
|
||||||
await respond_with_responder(
|
|
||||||
request,
|
|
||||||
responder,
|
|
||||||
file_info.thumbnail.type,
|
|
||||||
file_info.thumbnail.length,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If we can't find the thumbnail we regenerate it. This can happen
|
|
||||||
# if e.g. we've deleted the thumbnails but still have the original
|
|
||||||
# image somewhere.
|
|
||||||
#
|
|
||||||
# Since we have an entry for the thumbnail in the DB we a) know we
|
|
||||||
# have have successfully generated the thumbnail in the past (so we
|
|
||||||
# don't need to worry about repeatedly failing to generate
|
|
||||||
# thumbnails), and b) have already calculated that appropriate
|
|
||||||
# width/height/method so we can just call the "generate exact"
|
|
||||||
# methods.
|
|
||||||
|
|
||||||
# First let's check that we do actually have the original image
|
|
||||||
# still. This will throw a 404 if we don't.
|
|
||||||
# TODO: We should refetch the thumbnails for remote media.
|
|
||||||
await self.media_storage.ensure_media_is_in_local_cache(
|
|
||||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
|
||||||
)
|
|
||||||
|
|
||||||
if server_name:
|
|
||||||
await self.media_repo.generate_remote_exact_thumbnail(
|
|
||||||
server_name,
|
|
||||||
file_id=file_id,
|
|
||||||
media_id=media_id,
|
|
||||||
t_width=file_info.thumbnail.width,
|
|
||||||
t_height=file_info.thumbnail.height,
|
|
||||||
t_method=file_info.thumbnail.method,
|
|
||||||
t_type=file_info.thumbnail.type,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self.media_repo.generate_local_exact_thumbnail(
|
|
||||||
media_id=media_id,
|
|
||||||
t_width=file_info.thumbnail.width,
|
|
||||||
t_height=file_info.thumbnail.height,
|
|
||||||
t_method=file_info.thumbnail.method,
|
|
||||||
t_type=file_info.thumbnail.type,
|
|
||||||
url_cache=url_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
|
||||||
await respond_with_responder(
|
|
||||||
request,
|
|
||||||
responder,
|
|
||||||
file_info.thumbnail.type,
|
|
||||||
file_info.thumbnail.length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# This might be because:
|
|
||||||
# 1. We can't create thumbnails for the given media (corrupted or
|
|
||||||
# unsupported file type), or
|
|
||||||
# 2. The thumbnailing process never ran or errored out initially
|
|
||||||
# when the media was first uploaded (these bugs should be
|
|
||||||
# reported and fixed).
|
|
||||||
# Note that we don't attempt to generate a thumbnail now because
|
|
||||||
# `dynamic_thumbnails` is disabled.
|
|
||||||
logger.info("Failed to find any generated thumbnails")
|
|
||||||
|
|
||||||
assert request.path is not None
|
|
||||||
respond_with_json(
|
|
||||||
request,
|
|
||||||
400,
|
|
||||||
cs_error(
|
|
||||||
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
|
|
||||||
% (
|
|
||||||
request.path.decode(),
|
|
||||||
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
|
|
||||||
),
|
|
||||||
code=Codes.UNKNOWN,
|
|
||||||
),
|
|
||||||
send_cors=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _select_thumbnail(
|
|
||||||
self,
|
|
||||||
desired_width: int,
|
|
||||||
desired_height: int,
|
|
||||||
desired_method: str,
|
|
||||||
desired_type: str,
|
|
||||||
thumbnail_infos: List[ThumbnailInfo],
|
|
||||||
file_id: str,
|
|
||||||
url_cache: bool,
|
|
||||||
server_name: Optional[str],
|
|
||||||
) -> Optional[FileInfo]:
|
|
||||||
"""
|
|
||||||
Choose an appropriate thumbnail from the previously generated thumbnails.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
|
||||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
|
||||||
desired_method: The desired method used to generate the thumbnail.
|
|
||||||
desired_type: The desired content-type of the thumbnail.
|
|
||||||
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
|
|
||||||
file_id: The ID of the media that a thumbnail is being requested for.
|
|
||||||
url_cache: True if this is from a URL cache.
|
|
||||||
server_name: The server name, if this is a remote thumbnail.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The thumbnail which best matches the desired parameters.
|
|
||||||
"""
|
|
||||||
desired_method = desired_method.lower()
|
|
||||||
|
|
||||||
# The chosen thumbnail.
|
|
||||||
thumbnail_info = None
|
|
||||||
|
|
||||||
d_w = desired_width
|
|
||||||
d_h = desired_height
|
|
||||||
|
|
||||||
if desired_method == "crop":
|
|
||||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
|
||||||
crop_info_list: List[
|
|
||||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
|
||||||
] = []
|
|
||||||
# Other thumbnails.
|
|
||||||
crop_info_list2: List[
|
|
||||||
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
|
|
||||||
] = []
|
|
||||||
for info in thumbnail_infos:
|
|
||||||
# Skip thumbnails generated with different methods.
|
|
||||||
if info.method != "crop":
|
|
||||||
continue
|
|
||||||
|
|
||||||
t_w = info.width
|
|
||||||
t_h = info.height
|
|
||||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
|
||||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
|
||||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
|
||||||
type_quality = desired_type != info.type
|
|
||||||
length_quality = info.length
|
|
||||||
if t_w >= d_w or t_h >= d_h:
|
|
||||||
crop_info_list.append(
|
|
||||||
(
|
|
||||||
aspect_quality,
|
|
||||||
min_quality,
|
|
||||||
size_quality,
|
|
||||||
type_quality,
|
|
||||||
length_quality,
|
|
||||||
info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
crop_info_list2.append(
|
|
||||||
(
|
|
||||||
aspect_quality,
|
|
||||||
min_quality,
|
|
||||||
size_quality,
|
|
||||||
type_quality,
|
|
||||||
length_quality,
|
|
||||||
info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
|
||||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
|
||||||
# the thumbnail info and pick the thumbnail that appears earlier
|
|
||||||
# in the list of candidates.
|
|
||||||
if crop_info_list:
|
|
||||||
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
|
|
||||||
elif crop_info_list2:
|
|
||||||
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
|
|
||||||
elif desired_method == "scale":
|
|
||||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
|
||||||
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
|
||||||
# Other thumbnails.
|
|
||||||
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
|
|
||||||
|
|
||||||
for info in thumbnail_infos:
|
|
||||||
# Skip thumbnails generated with different methods.
|
|
||||||
if info.method != "scale":
|
|
||||||
continue
|
|
||||||
|
|
||||||
t_w = info.width
|
|
||||||
t_h = info.height
|
|
||||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
|
||||||
type_quality = desired_type != info.type
|
|
||||||
length_quality = info.length
|
|
||||||
if t_w >= d_w or t_h >= d_h:
|
|
||||||
info_list.append((size_quality, type_quality, length_quality, info))
|
|
||||||
else:
|
|
||||||
info_list2.append(
|
|
||||||
(size_quality, type_quality, length_quality, info)
|
|
||||||
)
|
|
||||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
|
||||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
|
||||||
# the thumbnail info and pick the thumbnail that appears earlier
|
|
||||||
# in the list of candidates.
|
|
||||||
if info_list:
|
|
||||||
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
|
|
||||||
elif info_list2:
|
|
||||||
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
|
|
||||||
|
|
||||||
if thumbnail_info:
|
|
||||||
return FileInfo(
|
|
||||||
file_id=file_id,
|
|
||||||
url_cache=url_cache,
|
|
||||||
server_name=server_name,
|
|
||||||
thumbnail=thumbnail_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
# No matching thumbnail was found.
|
|
||||||
return None
|
|
||||||
|
|
|
@ -2461,7 +2461,11 @@ class DatabasePool:
|
||||||
|
|
||||||
|
|
||||||
def make_in_list_sql_clause(
|
def make_in_list_sql_clause(
|
||||||
database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
|
database_engine: BaseDatabaseEngine,
|
||||||
|
column: str,
|
||||||
|
iterable: Collection[Any],
|
||||||
|
*,
|
||||||
|
negative: bool = False,
|
||||||
) -> Tuple[str, list]:
|
) -> Tuple[str, list]:
|
||||||
"""Returns an SQL clause that checks the given column is in the iterable.
|
"""Returns an SQL clause that checks the given column is in the iterable.
|
||||||
|
|
||||||
|
@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
|
||||||
database_engine
|
database_engine
|
||||||
column: Name of the column
|
column: Name of the column
|
||||||
iterable: The values to check the column against.
|
iterable: The values to check the column against.
|
||||||
|
negative: Whether we should check for inequality, i.e. `NOT IN`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of SQL query and the args
|
A tuple of SQL query and the args
|
||||||
|
@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
|
||||||
if database_engine.supports_using_any_list:
|
if database_engine.supports_using_any_list:
|
||||||
# This should hopefully be faster, but also makes postgres query
|
# This should hopefully be faster, but also makes postgres query
|
||||||
# stats easier to understand.
|
# stats easier to understand.
|
||||||
return "%s = ANY(?)" % (column,), [list(iterable)]
|
if not negative:
|
||||||
|
clause = f"{column} = ANY(?)"
|
||||||
|
else:
|
||||||
|
clause = f"{column} != ALL(?)"
|
||||||
|
|
||||||
|
return clause, [list(iterable)]
|
||||||
else:
|
else:
|
||||||
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
|
params = ",".join("?" for _ in iterable)
|
||||||
|
if not negative:
|
||||||
|
clause = f"{column} IN ({params})"
|
||||||
|
else:
|
||||||
|
clause = f"{column} NOT IN ({params})"
|
||||||
|
return clause, list(iterable)
|
||||||
|
|
||||||
|
|
||||||
# These overloads ensure that `columns` and `iterable` values have the same length.
|
# These overloads ensure that `columns` and `iterable` values have the same length.
|
||||||
|
|
|
@ -43,11 +43,9 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict, JsonMapping
|
from synapse.types import JsonDict, JsonMapping
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
|
|
||||||
self._account_data_id_gen: AbstractStreamIdGenerator
|
self._account_data_id_gen: AbstractStreamIdGenerator
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._account_data_id_gen = MultiWriterIdGenerator(
|
||||||
self._account_data_id_gen = MultiWriterIdGenerator(
|
db_conn=db_conn,
|
||||||
db_conn=db_conn,
|
db=database,
|
||||||
db=database,
|
notifier=hs.get_replication_notifier(),
|
||||||
notifier=hs.get_replication_notifier(),
|
stream_name="account_data",
|
||||||
stream_name="account_data",
|
instance_name=self._instance_name,
|
||||||
instance_name=self._instance_name,
|
tables=[
|
||||||
tables=[
|
("room_account_data", "instance_name", "stream_id"),
|
||||||
("room_account_data", "instance_name", "stream_id"),
|
("room_tags_revisions", "instance_name", "stream_id"),
|
||||||
("room_tags_revisions", "instance_name", "stream_id"),
|
("account_data", "instance_name", "stream_id"),
|
||||||
("account_data", "instance_name", "stream_id"),
|
],
|
||||||
],
|
sequence_name="account_data_sequence",
|
||||||
sequence_name="account_data_sequence",
|
writers=hs.config.worker.writers.account_data,
|
||||||
writers=hs.config.worker.writers.account_data,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Multiple writers are not supported for SQLite.
|
|
||||||
#
|
|
||||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
|
||||||
# to support it for unit tests.
|
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"room_account_data",
|
|
||||||
"stream_id",
|
|
||||||
extra_tables=[
|
|
||||||
("account_data", "stream_id"),
|
|
||||||
("room_tags_revisions", "stream_id"),
|
|
||||||
],
|
|
||||||
is_writer=self._instance_name in hs.config.worker.writers.account_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
account_max = self.get_max_account_data_stream_id()
|
account_max = self.get_max_account_data_stream_id()
|
||||||
self._account_data_stream_cache = StreamChangeCache(
|
self._account_data_stream_cache = StreamChangeCache(
|
||||||
|
|
|
@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
|
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
|
||||||
# Caches which might leak edits must be invalidated for the event being
|
# Caches which might leak edits must be invalidated for the event being
|
||||||
# redacted.
|
# redacted.
|
||||||
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
|
self._attempt_to_invalidate_cache(
|
||||||
|
"get_relations_for_event",
|
||||||
|
(
|
||||||
|
room_id,
|
||||||
|
redacts,
|
||||||
|
),
|
||||||
|
)
|
||||||
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
|
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
|
||||||
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
|
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
|
||||||
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
|
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
|
||||||
|
@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if relates_to:
|
if relates_to:
|
||||||
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
|
self._attempt_to_invalidate_cache(
|
||||||
|
"get_relations_for_event",
|
||||||
|
(
|
||||||
|
room_id,
|
||||||
|
relates_to,
|
||||||
|
),
|
||||||
|
)
|
||||||
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
|
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
|
||||||
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
|
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
|
||||||
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
|
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
|
||||||
|
@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self._attempt_to_invalidate_cache(
|
self._attempt_to_invalidate_cache(
|
||||||
"get_unread_event_push_actions_by_room_for_user", (room_id,)
|
"get_unread_event_push_actions_by_room_for_user", (room_id,)
|
||||||
)
|
)
|
||||||
|
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
|
||||||
|
|
||||||
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
|
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
|
||||||
self._attempt_to_invalidate_cache("get_relations_for_event", None)
|
|
||||||
self._attempt_to_invalidate_cache("get_applicable_edit", None)
|
self._attempt_to_invalidate_cache("get_applicable_edit", None)
|
||||||
self._attempt_to_invalidate_cache("get_thread_id", None)
|
self._attempt_to_invalidate_cache("get_thread_id", None)
|
||||||
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
|
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
|
||||||
|
|
|
@ -50,16 +50,15 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
make_in_list_sql_clause,
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -89,35 +88,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._can_write_to_device = (
|
||||||
self._can_write_to_device = (
|
self._instance_name in hs.config.worker.writers.to_device
|
||||||
self._instance_name in hs.config.worker.writers.to_device
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
|
self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
|
||||||
MultiWriterIdGenerator(
|
db_conn=db_conn,
|
||||||
db_conn=db_conn,
|
db=database,
|
||||||
db=database,
|
notifier=hs.get_replication_notifier(),
|
||||||
notifier=hs.get_replication_notifier(),
|
stream_name="to_device",
|
||||||
stream_name="to_device",
|
instance_name=self._instance_name,
|
||||||
instance_name=self._instance_name,
|
tables=[
|
||||||
tables=[
|
("device_inbox", "instance_name", "stream_id"),
|
||||||
("device_inbox", "instance_name", "stream_id"),
|
("device_federation_outbox", "instance_name", "stream_id"),
|
||||||
("device_federation_outbox", "instance_name", "stream_id"),
|
],
|
||||||
],
|
sequence_name="device_inbox_sequence",
|
||||||
sequence_name="device_inbox_sequence",
|
writers=hs.config.worker.writers.to_device,
|
||||||
writers=hs.config.worker.writers.to_device,
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._can_write_to_device = True
|
|
||||||
self._to_device_msg_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"device_inbox",
|
|
||||||
"stream_id",
|
|
||||||
extra_tables=[("device_federation_outbox", "stream_id")],
|
|
||||||
)
|
|
||||||
|
|
||||||
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
|
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
|
||||||
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
|
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
|
||||||
|
@ -978,6 +965,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||||
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
|
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
|
||||||
|
CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -1003,6 +991,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||||
self._remove_dead_devices_from_device_inbox,
|
self._remove_dead_devices_from_device_inbox,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_pool.updates.register_background_update_handler(
|
||||||
|
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||||
|
self._cleanup_device_federation_outbox,
|
||||||
|
)
|
||||||
|
|
||||||
async def _background_drop_index_device_inbox(
|
async def _background_drop_index_device_inbox(
|
||||||
self, progress: JsonDict, batch_size: int
|
self, progress: JsonDict, batch_size: int
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -1094,6 +1087,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
|
async def _cleanup_device_federation_outbox(
|
||||||
|
self,
|
||||||
|
progress: JsonDict,
|
||||||
|
batch_size: int,
|
||||||
|
) -> int:
|
||||||
|
def _cleanup_device_federation_outbox_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> bool:
|
||||||
|
if "max_stream_id" in progress:
|
||||||
|
max_stream_id = progress["max_stream_id"]
|
||||||
|
else:
|
||||||
|
txn.execute("SELECT max(stream_id) FROM device_federation_outbox")
|
||||||
|
res = cast(Tuple[Optional[int]], txn.fetchone())
|
||||||
|
if res[0] is None:
|
||||||
|
# this can only happen if the `device_inbox` table is empty, in which
|
||||||
|
# case we have no work to do.
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
max_stream_id = res[0]
|
||||||
|
|
||||||
|
start = progress.get("stream_id", 0)
|
||||||
|
stop = start + batch_size
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT destination FROM device_federation_outbox
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (start, stop))
|
||||||
|
|
||||||
|
destinations = {d for d, in txn}
|
||||||
|
to_remove = set()
|
||||||
|
for d in destinations:
|
||||||
|
try:
|
||||||
|
parse_and_validate_server_name(d)
|
||||||
|
except ValueError:
|
||||||
|
to_remove.add(d)
|
||||||
|
|
||||||
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="device_federation_outbox",
|
||||||
|
column="destination",
|
||||||
|
values=to_remove,
|
||||||
|
keyvalues={},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db_pool.updates._background_update_progress_txn(
|
||||||
|
txn,
|
||||||
|
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||||
|
{
|
||||||
|
"stream_id": stop,
|
||||||
|
"max_stream_id": max_stream_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return stop >= max_stream_id
|
||||||
|
|
||||||
|
finished = await self.db_pool.runInteraction(
|
||||||
|
"_cleanup_device_federation_outbox",
|
||||||
|
_cleanup_device_federation_outbox_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if finished:
|
||||||
|
await self.db_pool.updates._end_background_update(
|
||||||
|
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||||
|
)
|
||||||
|
|
||||||
|
return batch_size
|
||||||
|
|
||||||
|
|
||||||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -57,10 +57,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
AbstractStreamIdGenerator,
|
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
JsonMapping,
|
JsonMapping,
|
||||||
|
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._device_list_id_gen = StreamIdGenerator(
|
self._device_list_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"device_lists_stream",
|
notifier=hs.get_replication_notifier(),
|
||||||
"stream_id",
|
stream_name="device_lists_stream",
|
||||||
extra_tables=[
|
instance_name=self._instance_name,
|
||||||
("user_signature_stream", "stream_id"),
|
tables=[
|
||||||
("device_lists_outbound_pokes", "stream_id"),
|
("device_lists_stream", "instance_name", "stream_id"),
|
||||||
("device_lists_changes_in_room", "stream_id"),
|
("user_signature_stream", "instance_name", "stream_id"),
|
||||||
("device_lists_remote_pending", "stream_id"),
|
("device_lists_outbound_pokes", "instance_name", "stream_id"),
|
||||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
("device_lists_changes_in_room", "instance_name", "stream_id"),
|
||||||
|
("device_lists_remote_pending", "instance_name", "stream_id"),
|
||||||
],
|
],
|
||||||
is_writer=hs.config.worker.worker_app is None,
|
sequence_name="device_lists_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
device_list_max = self._device_list_id_gen.get_current_token()
|
device_list_max = self._device_list_id_gen.get_current_token()
|
||||||
|
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
"from_user_id": from_user_id,
|
"from_user_id": from_user_id,
|
||||||
"user_ids": json_encoder.encode(user_ids),
|
"user_ids": json_encoder.encode(user_ids),
|
||||||
|
"instance_name": self._instance_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
"device_lists_stream_idx",
|
"device_lists_stream_idx",
|
||||||
index_name="device_lists_stream_user_id",
|
index_name="device_lists_stream_user_id",
|
||||||
|
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
"device_lists_outbound_pokes",
|
"device_lists_outbound_pokes",
|
||||||
{
|
{
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"destination": destination,
|
"destination": destination,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
|
|
||||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
# Because we have write access, this will be a StreamIdGenerator
|
|
||||||
# (see DeviceWorkerStore.__init__)
|
|
||||||
_device_list_id_gen: AbstractStreamIdGenerator
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
self.db_pool.simple_insert_many_txn(
|
self.db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_stream",
|
table="device_lists_stream",
|
||||||
keys=("stream_id", "user_id", "device_id"),
|
keys=("instance_name", "stream_id", "user_id", "device_id"),
|
||||||
values=[
|
values=[
|
||||||
(stream_id, user_id, device_id)
|
(self._instance_name, stream_id, user_id, device_id)
|
||||||
for stream_id, device_id in zip(stream_ids, device_ids)
|
for stream_id, device_id in zip(stream_ids, device_ids)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
values = [
|
values = [
|
||||||
(
|
(
|
||||||
destination,
|
destination,
|
||||||
|
self._instance_name,
|
||||||
next(stream_id_iterator),
|
next(stream_id_iterator),
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
|
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
table="device_lists_outbound_pokes",
|
table="device_lists_outbound_pokes",
|
||||||
keys=(
|
keys=(
|
||||||
"destination",
|
"destination",
|
||||||
|
"instance_name",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"device_id",
|
"device_id",
|
||||||
|
@ -2157,10 +2158,34 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_id,
|
device_id,
|
||||||
{
|
{
|
||||||
stream_id: destination
|
stream_id: destination
|
||||||
for (destination, stream_id, _, _, _, _, _) in values
|
for (destination, _, stream_id, _, _, _, _, _) in values
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def mark_redundant_device_lists_pokes(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
room_id: str,
|
||||||
|
converted_upto_stream_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""If we've calculated the outbound pokes for a given room/device list
|
||||||
|
update, mark any subsequent changes as already converted"""
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
UPDATE device_lists_changes_in_room
|
||||||
|
SET converted_to_destinations = true
|
||||||
|
WHERE stream_id > ? AND user_id = ? AND device_id = ?
|
||||||
|
AND room_id = ? AND NOT converted_to_destinations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id))
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn
|
||||||
|
)
|
||||||
|
|
||||||
def _add_device_outbound_room_poke_txn(
|
def _add_device_outbound_room_poke_txn(
|
||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
|
@ -2186,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
"device_id",
|
"device_id",
|
||||||
"room_id",
|
"room_id",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
|
"instance_name",
|
||||||
"converted_to_destinations",
|
"converted_to_destinations",
|
||||||
"opentracing_context",
|
"opentracing_context",
|
||||||
),
|
),
|
||||||
|
@ -2195,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_id,
|
device_id,
|
||||||
room_id,
|
room_id,
|
||||||
stream_id,
|
stream_id,
|
||||||
|
self._instance_name,
|
||||||
# We only need to calculate outbound pokes for local users
|
# We only need to calculate outbound pokes for local users
|
||||||
not self.hs.is_mine_id(user_id),
|
not self.hs.is_mine_id(user_id),
|
||||||
encoded_context,
|
encoded_context,
|
||||||
|
@ -2314,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
values={"stream_id": stream_id},
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
|
},
|
||||||
desc="add_remote_device_list_to_pending",
|
desc="add_remote_device_list_to_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.types import JsonDict, JsonMapping
|
from synapse.types import JsonDict, JsonMapping
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
self._cross_signing_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"e2e_cross_signing_keys",
|
notifier=hs.get_replication_notifier(),
|
||||||
"stream_id",
|
stream_name="e2e_cross_signing_keys",
|
||||||
|
instance_name=self._instance_name,
|
||||||
|
tables=[
|
||||||
|
("e2e_cross_signing_keys", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="e2e_cross_signing_keys_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_e2e_device_keys(
|
async def set_e2e_device_keys(
|
||||||
|
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
"keytype": key_type,
|
"keytype": key_type,
|
||||||
"keydata": json_encoder.encode(key),
|
"keydata": json_encoder.encode(key),
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -95,6 +95,10 @@ class DeltaState:
|
||||||
to_insert: StateMap[str]
|
to_insert: StateMap[str]
|
||||||
no_longer_in_room: bool = False
|
no_longer_in_room: bool = False
|
||||||
|
|
||||||
|
def is_noop(self) -> bool:
|
||||||
|
"""Whether this state delta is actually empty"""
|
||||||
|
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||||
|
|
||||||
|
|
||||||
class PersistEventsStore:
|
class PersistEventsStore:
|
||||||
"""Contains all the functions for writing events to the database.
|
"""Contains all the functions for writing events to the database.
|
||||||
|
@ -1017,6 +1021,9 @@ class PersistEventsStore:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the current state stored in the datatabase for the given room"""
|
"""Update the current state stored in the datatabase for the given room"""
|
||||||
|
|
||||||
|
if state_delta.is_noop():
|
||||||
|
return
|
||||||
|
|
||||||
async with self._stream_id_gen.get_next() as stream_ordering:
|
async with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"update_current_state",
|
"update_current_state",
|
||||||
|
@ -1923,7 +1930,12 @@ class PersistEventsStore:
|
||||||
|
|
||||||
# Any relation information for the related event must be cleared.
|
# Any relation information for the related event must be cleared.
|
||||||
self.store._invalidate_cache_and_stream(
|
self.store._invalidate_cache_and_stream(
|
||||||
txn, self.store.get_relations_for_event, (redacted_relates_to,)
|
txn,
|
||||||
|
self.store.get_relations_for_event,
|
||||||
|
(
|
||||||
|
room_id,
|
||||||
|
redacted_relates_to,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if rel_type == RelationTypes.REFERENCE:
|
if rel_type == RelationTypes.REFERENCE:
|
||||||
self.store._invalidate_cache_and_stream(
|
self.store._invalidate_cache_and_stream(
|
||||||
|
|
|
@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
|
|
||||||
results = list(txn)
|
results = list(txn)
|
||||||
# (event_id, parent_id, rel_type) for each relation
|
# (event_id, parent_id, rel_type) for each relation
|
||||||
relations_to_insert: List[Tuple[str, str, str]] = []
|
relations_to_insert: List[Tuple[str, str, str, str]] = []
|
||||||
for event_id, event_json_raw in results:
|
for event_id, event_json_raw in results:
|
||||||
try:
|
try:
|
||||||
event_json = db_to_json(event_json_raw)
|
event_json = db_to_json(event_json_raw)
|
||||||
|
@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
if not isinstance(parent_id, str):
|
if not isinstance(parent_id, str):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
relations_to_insert.append((event_id, parent_id, rel_type))
|
room_id = event_json["room_id"]
|
||||||
|
relations_to_insert.append((room_id, event_id, parent_id, rel_type))
|
||||||
|
|
||||||
# Insert the missing data, note that we upsert here in case the event
|
# Insert the missing data, note that we upsert here in case the event
|
||||||
# has already been processed.
|
# has already been processed.
|
||||||
|
@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="event_relations",
|
table="event_relations",
|
||||||
key_names=("event_id",),
|
key_names=("event_id",),
|
||||||
key_values=[(r[0],) for r in relations_to_insert],
|
key_values=[(r[1],) for r in relations_to_insert],
|
||||||
value_names=("relates_to_id", "relation_type"),
|
value_names=("relates_to_id", "relation_type"),
|
||||||
value_values=[r[1:] for r in relations_to_insert],
|
value_values=[r[2:] for r in relations_to_insert],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Iterate the parent IDs and invalidate caches.
|
# Iterate the parent IDs and invalidate caches.
|
||||||
cache_tuples = {(r[1],) for r in relations_to_insert}
|
|
||||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||||
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
|
txn,
|
||||||
|
self.get_relations_for_event, # type: ignore[attr-defined]
|
||||||
|
{
|
||||||
|
(
|
||||||
|
r[0], # room_id
|
||||||
|
r[2], # parent_id
|
||||||
|
)
|
||||||
|
for r in relations_to_insert
|
||||||
|
},
|
||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
|
||||||
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
|
txn,
|
||||||
|
self.get_thread_summary, # type: ignore[attr-defined]
|
||||||
|
{(r[1],) for r in relations_to_insert},
|
||||||
)
|
)
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
|
|
|
@ -75,12 +75,10 @@ from synapse.storage.database import (
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
|
@ -195,51 +193,35 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
self._stream_id_gen: AbstractStreamIdGenerator
|
self._stream_id_gen: AbstractStreamIdGenerator
|
||||||
self._backfill_id_gen: AbstractStreamIdGenerator
|
self._backfill_id_gen: AbstractStreamIdGenerator
|
||||||
if isinstance(database.engine, PostgresEngine):
|
|
||||||
# If we're using Postgres than we can use `MultiWriterIdGenerator`
|
self._stream_id_gen = MultiWriterIdGenerator(
|
||||||
# regardless of whether this process writes to the streams or not.
|
db_conn=db_conn,
|
||||||
self._stream_id_gen = MultiWriterIdGenerator(
|
db=database,
|
||||||
db_conn=db_conn,
|
notifier=hs.get_replication_notifier(),
|
||||||
db=database,
|
stream_name="events",
|
||||||
notifier=hs.get_replication_notifier(),
|
instance_name=hs.get_instance_name(),
|
||||||
stream_name="events",
|
tables=[
|
||||||
instance_name=hs.get_instance_name(),
|
("events", "instance_name", "stream_ordering"),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
("current_state_delta_stream", "instance_name", "stream_id"),
|
||||||
sequence_name="events_stream_seq",
|
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||||
writers=hs.config.worker.writers.events,
|
],
|
||||||
)
|
sequence_name="events_stream_seq",
|
||||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
writers=hs.config.worker.writers.events,
|
||||||
db_conn=db_conn,
|
)
|
||||||
db=database,
|
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||||
notifier=hs.get_replication_notifier(),
|
db_conn=db_conn,
|
||||||
stream_name="backfill",
|
db=database,
|
||||||
instance_name=hs.get_instance_name(),
|
notifier=hs.get_replication_notifier(),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
stream_name="backfill",
|
||||||
sequence_name="events_backfill_stream_seq",
|
instance_name=hs.get_instance_name(),
|
||||||
positive=False,
|
tables=[
|
||||||
writers=hs.config.worker.writers.events,
|
("events", "instance_name", "stream_ordering"),
|
||||||
)
|
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||||
else:
|
],
|
||||||
# Multiple writers are not supported for SQLite.
|
sequence_name="events_backfill_stream_seq",
|
||||||
#
|
positive=False,
|
||||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
writers=hs.config.worker.writers.events,
|
||||||
# to support it for unit tests.
|
)
|
||||||
self._stream_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"events",
|
|
||||||
"stream_ordering",
|
|
||||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
|
||||||
)
|
|
||||||
self._backfill_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"events",
|
|
||||||
"stream_ordering",
|
|
||||||
step=-1,
|
|
||||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
|
||||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
|
||||||
)
|
|
||||||
|
|
||||||
events_max = self._stream_id_gen.get_current_token()
|
events_max = self._stream_id_gen.get_current_token()
|
||||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
||||||
|
@ -309,27 +291,17 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
|
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
||||||
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
db_conn=db_conn,
|
||||||
db_conn=db_conn,
|
db=database,
|
||||||
db=database,
|
notifier=hs.get_replication_notifier(),
|
||||||
notifier=hs.get_replication_notifier(),
|
stream_name="un_partial_stated_event_stream",
|
||||||
stream_name="un_partial_stated_event_stream",
|
instance_name=hs.get_instance_name(),
|
||||||
instance_name=hs.get_instance_name(),
|
tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
|
||||||
tables=[
|
sequence_name="un_partial_stated_event_stream_sequence",
|
||||||
("un_partial_stated_event_stream", "instance_name", "stream_id")
|
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||||
],
|
writers=["master"],
|
||||||
sequence_name="un_partial_stated_event_stream_sequence",
|
)
|
||||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
|
||||||
writers=["master"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"un_partial_stated_event_stream",
|
|
||||||
"stream_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
|
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -40,13 +40,11 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.engines._base import IsolationLevel
|
from synapse.storage.engines._base import IsolationLevel
|
||||||
from synapse.storage.types import Connection
|
from synapse.storage.types import Connection
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
self._instance_name in hs.config.worker.writers.presence
|
self._instance_name in hs.config.worker.writers.presence
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._presence_id_gen = MultiWriterIdGenerator(
|
||||||
self._presence_id_gen = MultiWriterIdGenerator(
|
db_conn=db_conn,
|
||||||
db_conn=db_conn,
|
db=database,
|
||||||
db=database,
|
notifier=hs.get_replication_notifier(),
|
||||||
notifier=hs.get_replication_notifier(),
|
stream_name="presence_stream",
|
||||||
stream_name="presence_stream",
|
instance_name=self._instance_name,
|
||||||
instance_name=self._instance_name,
|
tables=[("presence_stream", "instance_name", "stream_id")],
|
||||||
tables=[("presence_stream", "instance_name", "stream_id")],
|
sequence_name="presence_stream_sequence",
|
||||||
sequence_name="presence_stream_sequence",
|
writers=hs.config.worker.writers.presence,
|
||||||
writers=hs.config.worker.writers.presence,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._presence_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||||
|
|
|
@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||||
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder, unwrapFirstError
|
from synapse.util import json_encoder, unwrapFirstError
|
||||||
|
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
|
||||||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_push_rules_stream_id_gen: StreamIdGenerator
|
_push_rules_stream_id_gen: MultiWriterIdGenerator
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
|
||||||
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
||||||
)
|
)
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
|
||||||
# class below that is used on the main process.
|
db_conn=db_conn,
|
||||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
db=database,
|
||||||
db_conn,
|
notifier=hs.get_replication_notifier(),
|
||||||
hs.get_replication_notifier(),
|
stream_name="push_rules_stream",
|
||||||
"push_rules_stream",
|
instance_name=self._instance_name,
|
||||||
"stream_id",
|
tables=[
|
||||||
is_writer=self._is_push_writer,
|
("push_rules_stream", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="push_rules_stream_sequence",
|
||||||
|
writers=hs.config.worker.writers.push_rules,
|
||||||
)
|
)
|
||||||
|
|
||||||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
||||||
|
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
|
||||||
raise Exception("Not a push writer")
|
raise Exception("Not a push writer")
|
||||||
|
|
||||||
values = {
|
values = {
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
"event_stream_ordering": event_stream_ordering,
|
"event_stream_ordering": event_stream_ordering,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
|
@ -40,10 +40,7 @@ from synapse.storage.database import (
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
AbstractStreamIdGenerator,
|
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
self._instance_name = hs.get_instance_name()
|
||||||
# class below that is used on the main process.
|
|
||||||
self._pushers_id_gen = StreamIdGenerator(
|
self._pushers_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"pushers",
|
notifier=hs.get_replication_notifier(),
|
||||||
"id",
|
stream_name="pushers",
|
||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
instance_name=self._instance_name,
|
||||||
is_writer=hs.config.worker.worker_app is None,
|
tables=[
|
||||||
|
("pushers", "instance_name", "id"),
|
||||||
|
("deleted_pushers", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="pushers_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db_pool.updates.register_background_update_handler(
|
self.db_pool.updates.register_background_update_handler(
|
||||||
|
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
||||||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
# Because we have write access, this will be a StreamIdGenerator
|
# Because we have write access, this will be a StreamIdGenerator
|
||||||
# (see PusherWorkerStore.__init__)
|
# (see PusherWorkerStore.__init__)
|
||||||
_pushers_id_gen: AbstractStreamIdGenerator
|
_pushers_id_gen: MultiWriterIdGenerator
|
||||||
|
|
||||||
async def add_pusher(
|
async def add_pusher(
|
||||||
self,
|
self,
|
||||||
|
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
"last_stream_ordering": last_stream_ordering,
|
"last_stream_ordering": last_stream_ordering,
|
||||||
"profile_tag": profile_tag,
|
"profile_tag": profile_tag,
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"enabled": enabled,
|
"enabled": enabled,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
# XXX(quenting): We're only really persisting the access token ID
|
# XXX(quenting): We're only really persisting the access token ID
|
||||||
|
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
table="deleted_pushers",
|
table="deleted_pushers",
|
||||||
values={
|
values={
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"pushkey": pushkey,
|
"pushkey": pushkey,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
self.db_pool.simple_insert_many_txn(
|
self.db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="deleted_pushers",
|
table="deleted_pushers",
|
||||||
keys=("stream_id", "app_id", "pushkey", "user_id"),
|
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
|
||||||
values=[
|
values=[
|
||||||
(stream_id, pusher.app_id, pusher.pushkey, user_id)
|
(
|
||||||
|
stream_id,
|
||||||
|
self._instance_name,
|
||||||
|
pusher.app_id,
|
||||||
|
pusher.pushkey,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
for stream_id, pusher in zip(stream_ids, pushers)
|
for stream_id, pusher in zip(stream_ids, pushers)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,12 +44,10 @@ from synapse.storage.database import (
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.engines._base import IsolationLevel
|
from synapse.storage.engines._base import IsolationLevel
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
|
@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._receipts_id_gen: AbstractStreamIdGenerator
|
self._receipts_id_gen: AbstractStreamIdGenerator
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._can_write_to_receipts = (
|
||||||
self._can_write_to_receipts = (
|
self._instance_name in hs.config.worker.writers.receipts
|
||||||
self._instance_name in hs.config.worker.writers.receipts
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self._receipts_id_gen = MultiWriterIdGenerator(
|
self._receipts_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
notifier=hs.get_replication_notifier(),
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="receipts",
|
stream_name="receipts",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
||||||
sequence_name="receipts_sequence",
|
sequence_name="receipts_sequence",
|
||||||
writers=hs.config.worker.writers.receipts,
|
writers=hs.config.worker.writers.receipts,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self._can_write_to_receipts = True
|
|
||||||
|
|
||||||
# Multiple writers are not supported for SQLite.
|
|
||||||
#
|
|
||||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
|
||||||
# to support it for unit tests.
|
|
||||||
self._receipts_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"receipts_linearized",
|
|
||||||
"stream_id",
|
|
||||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
|
|
@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
@cached(uncached_args=("event",), tree=True)
|
@cached(uncached_args=("event",), tree=True)
|
||||||
async def get_relations_for_event(
|
async def get_relations_for_event(
|
||||||
self,
|
self,
|
||||||
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
room_id: str,
|
|
||||||
relation_type: Optional[str] = None,
|
relation_type: Optional[str] = None,
|
||||||
event_type: Optional[str] = None,
|
event_type: Optional[str] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
|
|
|
@ -58,13 +58,11 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdGenerator,
|
AbstractStreamIdGenerator,
|
||||||
IdGenerator,
|
IdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
|
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
|
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
||||||
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
db_conn=db_conn,
|
||||||
db_conn=db_conn,
|
db=database,
|
||||||
db=database,
|
notifier=hs.get_replication_notifier(),
|
||||||
notifier=hs.get_replication_notifier(),
|
stream_name="un_partial_stated_room_stream",
|
||||||
stream_name="un_partial_stated_room_stream",
|
instance_name=self._instance_name,
|
||||||
instance_name=self._instance_name,
|
tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
|
||||||
tables=[
|
sequence_name="un_partial_stated_room_stream_sequence",
|
||||||
("un_partial_stated_room_stream", "instance_name", "stream_id")
|
# TODO(faster_joins, multiple writers) Support multiple writers.
|
||||||
],
|
writers=["master"],
|
||||||
sequence_name="un_partial_stated_room_stream_sequence",
|
)
|
||||||
# TODO(faster_joins, multiple writers) Support multiple writers.
|
|
||||||
writers=["master"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
|
|
||||||
db_conn,
|
|
||||||
hs.get_replication_notifier(),
|
|
||||||
"un_partial_stated_room_stream",
|
|
||||||
"stream_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_replication_position(
|
def process_replication_position(
|
||||||
self, stream_name: str, instance_name: str, token: int
|
self, stream_name: str, instance_name: str, token: int
|
||||||
|
|
|
@ -142,6 +142,10 @@ class PostgresEngine(
|
||||||
apply stricter checks on new databases versus existing database.
|
apply stricter checks on new databases versus existing database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
|
||||||
|
if allow_unsafe_locale:
|
||||||
|
return
|
||||||
|
|
||||||
collation, ctype = self.get_db_locale(txn)
|
collation, ctype = self.get_db_locale(txn)
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
@ -155,7 +159,9 @@ class PostgresEngine(
|
||||||
if errors:
|
if errors:
|
||||||
raise IncorrectDatabaseSetup(
|
raise IncorrectDatabaseSetup(
|
||||||
"Database is incorrectly configured:\n\n%s\n\n"
|
"Database is incorrectly configured:\n\n%s\n\n"
|
||||||
"See docs/postgres.md for more information." % ("\n".join(errors))
|
"See docs/postgres.md for more information. You can override this check by"
|
||||||
|
"setting 'allow_unsafe_locale' to true in the database config.",
|
||||||
|
"\n".join(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
def convert_param_style(self, sql: str) -> str:
|
def convert_param_style(self, sql: str) -> str:
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Add `instance_name` columns to stream tables to allow them to be used with
|
||||||
|
-- `MultiWriterIdGenerator`
|
||||||
|
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
|
|
@ -0,0 +1,54 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Add squences for stream tables to allow them to be used with
|
||||||
|
-- `MultiWriterIdGenerator`
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
|
||||||
|
|
||||||
|
-- We need to take the max across all the device lists tables as they share the
|
||||||
|
-- ID generator
|
||||||
|
SELECT setval('device_lists_sequence', (
|
||||||
|
SELECT GREATEST(
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
|
||||||
|
)
|
||||||
|
));
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
|
||||||
|
|
||||||
|
SELECT setval('e2e_cross_signing_keys_sequence', (
|
||||||
|
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
|
||||||
|
));
|
||||||
|
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
|
||||||
|
|
||||||
|
SELECT setval('push_rules_stream_sequence', (
|
||||||
|
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
|
||||||
|
));
|
||||||
|
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
|
||||||
|
|
||||||
|
-- We need to take the max across all the pusher tables as they share the
|
||||||
|
-- ID generator
|
||||||
|
SELECT setval('pushers_sequence', (
|
||||||
|
SELECT GREATEST(
|
||||||
|
(SELECT COALESCE(MAX(id), 1) FROM pushers),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
|
||||||
|
)
|
||||||
|
));
|
|
@ -0,0 +1,15 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(8504, 'cleanup_device_federation_outbox', '{}');
|
|
@ -23,15 +23,12 @@ import abc
|
||||||
import heapq
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -53,9 +50,11 @@ from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.notifier import ReplicationNotifier
|
from synapse.notifier import ReplicationNotifier
|
||||||
|
@ -177,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
|
||||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
|
||||||
|
|
||||||
This class must only be used when the current Synapse process is the sole
|
|
||||||
writer for a stream.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_conn(connection): A database connection to use to fetch the
|
|
||||||
initial value of the generator from.
|
|
||||||
table(str): A database table to read the initial value of the id
|
|
||||||
generator from.
|
|
||||||
column(str): The column of the database table to read the initial
|
|
||||||
value from the id generator from.
|
|
||||||
extra_tables(list): List of pairs of database tables and columns to
|
|
||||||
use to source the initial value of the generator from. The value
|
|
||||||
with the largest magnitude is used.
|
|
||||||
step(int): which direction the stream ids grow in. +1 to grow
|
|
||||||
upwards, -1 to grow downwards.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
async with stream_id_gen.get_next() as stream_id:
|
|
||||||
# ... persist event ...
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
db_conn: LoggingDatabaseConnection,
|
|
||||||
notifier: "ReplicationNotifier",
|
|
||||||
table: str,
|
|
||||||
column: str,
|
|
||||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
|
||||||
step: int = 1,
|
|
||||||
is_writer: bool = True,
|
|
||||||
) -> None:
|
|
||||||
assert step != 0
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._step: int = step
|
|
||||||
self._current: int = _load_current_id(db_conn, table, column, step)
|
|
||||||
self._is_writer = is_writer
|
|
||||||
for table, column in extra_tables:
|
|
||||||
self._current = (max if step > 0 else min)(
|
|
||||||
self._current, _load_current_id(db_conn, table, column, step)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We use this as an ordered set, as we want to efficiently append items,
|
|
||||||
# remove items and get the first item. Since we insert IDs in order, the
|
|
||||||
# insertion ordering will ensure its in the correct ordering.
|
|
||||||
#
|
|
||||||
# The key and values are the same, but we never look at the values.
|
|
||||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
|
||||||
|
|
||||||
self._notifier = notifier
|
|
||||||
|
|
||||||
def advance(self, instance_name: str, new_id: int) -> None:
|
|
||||||
# Advance should never be called on a writer instance, only over replication
|
|
||||||
if self._is_writer:
|
|
||||||
raise Exception("Replication is not supported by writer StreamIdGenerator")
|
|
||||||
|
|
||||||
self._current = (max if self._step > 0 else min)(self._current, new_id)
|
|
||||||
|
|
||||||
def get_next(self) -> AsyncContextManager[int]:
|
|
||||||
with self._lock:
|
|
||||||
self._current += self._step
|
|
||||||
next_id = self._current
|
|
||||||
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def manager() -> Generator[int, None, None]:
|
|
||||||
try:
|
|
||||||
yield next_id
|
|
||||||
finally:
|
|
||||||
with self._lock:
|
|
||||||
self._unfinished_ids.pop(next_id)
|
|
||||||
|
|
||||||
self._notifier.notify_replication()
|
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
|
||||||
|
|
||||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
|
||||||
with self._lock:
|
|
||||||
next_ids = range(
|
|
||||||
self._current + self._step,
|
|
||||||
self._current + self._step * (n + 1),
|
|
||||||
self._step,
|
|
||||||
)
|
|
||||||
self._current += n * self._step
|
|
||||||
|
|
||||||
for next_id in next_ids:
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def manager() -> Generator[Sequence[int], None, None]:
|
|
||||||
try:
|
|
||||||
yield next_ids
|
|
||||||
finally:
|
|
||||||
with self._lock:
|
|
||||||
for next_id in next_ids:
|
|
||||||
self._unfinished_ids.pop(next_id)
|
|
||||||
|
|
||||||
self._notifier.notify_replication()
|
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
|
||||||
"""
|
|
||||||
Retrieve the next stream ID from within a database transaction.
|
|
||||||
|
|
||||||
Clean-up functions will be called when the transaction finishes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
txn: The database transaction object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next stream ID.
|
|
||||||
"""
|
|
||||||
if not self._is_writer:
|
|
||||||
raise Exception("Tried to allocate stream ID on non-writer")
|
|
||||||
|
|
||||||
# Get the next stream ID.
|
|
||||||
with self._lock:
|
|
||||||
self._current += self._step
|
|
||||||
next_id = self._current
|
|
||||||
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
def clear_unfinished_id(id_to_clear: int) -> None:
|
|
||||||
"""A function to mark processing this ID as finished"""
|
|
||||||
with self._lock:
|
|
||||||
self._unfinished_ids.pop(id_to_clear)
|
|
||||||
|
|
||||||
# Mark this ID as finished once the database transaction itself finishes.
|
|
||||||
txn.call_after(clear_unfinished_id, next_id)
|
|
||||||
txn.call_on_exception(clear_unfinished_id, next_id)
|
|
||||||
|
|
||||||
# Return the new ID.
|
|
||||||
return next_id
|
|
||||||
|
|
||||||
def get_current_token(self) -> int:
|
|
||||||
if not self._is_writer:
|
|
||||||
return self._current
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if self._unfinished_ids:
|
|
||||||
return next(iter(self._unfinished_ids)) - self._step
|
|
||||||
|
|
||||||
return self._current
|
|
||||||
|
|
||||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
|
||||||
return self.get_current_token()
|
|
||||||
|
|
||||||
def get_minimal_local_current_token(self) -> int:
|
|
||||||
return self.get_current_token()
|
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||||
|
|
||||||
|
@ -432,7 +276,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
# no active writes in progress.
|
# no active writes in progress.
|
||||||
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
|
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
|
||||||
|
|
||||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
# This goes and fills out the above state from the database.
|
||||||
|
self._load_current_ids(db_conn, tables)
|
||||||
|
|
||||||
|
self._sequence_gen = build_sequence_generator(
|
||||||
|
db_conn=db_conn,
|
||||||
|
database_engine=db.engine,
|
||||||
|
get_first_callback=lambda _: self._persisted_upto_position,
|
||||||
|
sequence_name=sequence_name,
|
||||||
|
# We only need to set the below if we want it to call
|
||||||
|
# `check_consistency`, but we do that ourselves below so we can
|
||||||
|
# leave them blank.
|
||||||
|
table=None,
|
||||||
|
id_column=None,
|
||||||
|
stream_name=None,
|
||||||
|
positive=positive,
|
||||||
|
)
|
||||||
|
|
||||||
# We check that the table and sequence haven't diverged.
|
# We check that the table and sequence haven't diverged.
|
||||||
for table, _, id_column in tables:
|
for table, _, id_column in tables:
|
||||||
|
@ -444,9 +303,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
positive=positive,
|
positive=positive,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This goes and fills out the above state from the database.
|
|
||||||
self._load_current_ids(db_conn, tables)
|
|
||||||
|
|
||||||
self._max_seen_allocated_stream_id = max(
|
self._max_seen_allocated_stream_id = max(
|
||||||
self._current_positions.values(), default=1
|
self._current_positions.values(), default=1
|
||||||
)
|
)
|
||||||
|
@ -480,13 +336,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
# important if we add back a writer after a long time; we want to
|
# important if we add back a writer after a long time; we want to
|
||||||
# consider that a "new" writer, rather than using the old stale
|
# consider that a "new" writer, rather than using the old stale
|
||||||
# entry here.
|
# entry here.
|
||||||
sql = """
|
clause, args = make_in_list_sql_clause(
|
||||||
|
self._db.engine, "instance_name", self._writers, negative=True
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
DELETE FROM stream_positions
|
DELETE FROM stream_positions
|
||||||
WHERE
|
WHERE
|
||||||
stream_name = ?
|
stream_name = ?
|
||||||
AND instance_name != ALL(?)
|
AND {clause}
|
||||||
"""
|
"""
|
||||||
cur.execute(sql, (self._stream_name, self._writers))
|
cur.execute(sql, [self._stream_name] + args)
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT instance_name, stream_id FROM stream_positions
|
SELECT instance_name, stream_id FROM stream_positions
|
||||||
|
@ -508,12 +368,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
# We add a GREATEST here to ensure that the result is always
|
# We add a GREATEST here to ensure that the result is always
|
||||||
# positive. (This can be a problem for e.g. backfill streams where
|
# positive. (This can be a problem for e.g. backfill streams where
|
||||||
# the server has never backfilled).
|
# the server has never backfilled).
|
||||||
|
greatest_func = (
|
||||||
|
"GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
|
||||||
|
)
|
||||||
max_stream_id = 1
|
max_stream_id = 1
|
||||||
for table, _, id_column in tables:
|
for table, _, id_column in tables:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
|
SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
|
||||||
FROM %(table)s
|
FROM %(table)s
|
||||||
""" % {
|
""" % {
|
||||||
|
"greatest_func": greatest_func,
|
||||||
"id": id_column,
|
"id": id_column,
|
||||||
"table": table,
|
"table": table,
|
||||||
"agg": "MAX" if self._positive else "-MIN",
|
"agg": "MAX" if self._positive else "-MIN",
|
||||||
|
@ -913,6 +777,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
|
|
||||||
# We upsert the value, ensuring on conflict that we always increase the
|
# We upsert the value, ensuring on conflict that we always increase the
|
||||||
# value (or decrease if stream goes backwards).
|
# value (or decrease if stream goes backwards).
|
||||||
|
if isinstance(self._db.engine, PostgresEngine):
|
||||||
|
agg = "GREATEST" if self._positive else "LEAST"
|
||||||
|
else:
|
||||||
|
agg = "MAX" if self._positive else "MIN"
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
|
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
|
||||||
VALUES (?, ?, ?)
|
VALUES (?, ?, ?)
|
||||||
|
@ -920,10 +789,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
DO UPDATE SET
|
DO UPDATE SET
|
||||||
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
|
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
|
||||||
""" % {
|
""" % {
|
||||||
"agg": "GREATEST" if self._positive else "LEAST",
|
"agg": agg,
|
||||||
}
|
}
|
||||||
|
|
||||||
pos = (self.get_current_token_for_writer(self._instance_name),)
|
pos = self.get_current_token_for_writer(self._instance_name)
|
||||||
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ import attr
|
||||||
from immutabledict import immutabledict
|
from immutabledict import immutabledict
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.types import VerifyKey
|
from signedjson.types import VerifyKey
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import Self, TypedDict
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
from zope.interface import Interface
|
from zope.interface import Interface
|
||||||
|
|
||||||
|
@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
||||||
# at `self.stream`.
|
# at `self.stream`.
|
||||||
return self.instance_map.get(instance_name, self.stream)
|
return self.instance_map.get(instance_name, self.stream)
|
||||||
|
|
||||||
|
def is_before_or_eq(self, other_token: Self) -> bool:
|
||||||
|
"""Wether this token is before the other token, i.e. every constituent
|
||||||
|
part is before the other.
|
||||||
|
|
||||||
|
Essentially it is `self <= other`.
|
||||||
|
|
||||||
|
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||||
|
imply that the reverse is True.
|
||||||
|
"""
|
||||||
|
if self.stream > other_token.stream:
|
||||||
|
return False
|
||||||
|
|
||||||
|
instances = self.instance_map.keys() | other_token.instance_map.keys()
|
||||||
|
for instance in instances:
|
||||||
|
if self.instance_map.get(
|
||||||
|
instance, self.stream
|
||||||
|
) > other_token.instance_map.get(instance, other_token.stream):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True, order=False)
|
@attr.s(frozen=True, slots=True, order=False)
|
||||||
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
||||||
|
@ -1008,6 +1029,41 @@ class StreamToken:
|
||||||
"""Returns the stream ID for the given key."""
|
"""Returns the stream ID for the given key."""
|
||||||
return getattr(self, key.value)
|
return getattr(self, key.value)
|
||||||
|
|
||||||
|
def is_before_or_eq(self, other_token: "StreamToken") -> bool:
|
||||||
|
"""Wether this token is before the other token, i.e. every constituent
|
||||||
|
part is before the other.
|
||||||
|
|
||||||
|
Essentially it is `self <= other`.
|
||||||
|
|
||||||
|
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||||
|
imply that the reverse is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for _, key in StreamKeyType.__members__.items():
|
||||||
|
if key == StreamKeyType.TYPING:
|
||||||
|
# Typing stream is allowed to "reset", and so comparisons don't
|
||||||
|
# really make sense as is.
|
||||||
|
# TODO: Figure out a better way of tracking resets.
|
||||||
|
continue
|
||||||
|
|
||||||
|
self_value = self.get_field(key)
|
||||||
|
other_value = other_token.get_field(key)
|
||||||
|
|
||||||
|
if isinstance(self_value, RoomStreamToken):
|
||||||
|
assert isinstance(other_value, RoomStreamToken)
|
||||||
|
if not self_value.is_before_or_eq(other_value):
|
||||||
|
return False
|
||||||
|
elif isinstance(self_value, MultiWriterStreamToken):
|
||||||
|
assert isinstance(other_value, MultiWriterStreamToken)
|
||||||
|
if not self_value.is_before_or_eq(other_value):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
assert isinstance(other_value, int)
|
||||||
|
if self_value > other_value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
StreamToken.START = StreamToken(
|
StreamToken.START = StreamToken(
|
||||||
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
|
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
|
||||||
|
|
|
@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import nested_logging_context
|
from synapse.logging.context import (
|
||||||
|
ContextResourceUsage,
|
||||||
|
LoggingContext,
|
||||||
|
nested_logging_context,
|
||||||
|
set_current_context,
|
||||||
|
)
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
|
@ -81,6 +86,8 @@ class TaskScheduler:
|
||||||
MAX_CONCURRENT_RUNNING_TASKS = 5
|
MAX_CONCURRENT_RUNNING_TASKS = 5
|
||||||
# Time from the last task update after which we will log a warning
|
# Time from the last task update after which we will log a warning
|
||||||
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
||||||
|
# Report a running task's status and usage every so often.
|
||||||
|
OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._hs = hs
|
self._hs = hs
|
||||||
|
@ -346,6 +353,33 @@ class TaskScheduler:
|
||||||
assert task.id not in self._running_tasks
|
assert task.id not in self._running_tasks
|
||||||
await self._store.delete_scheduled_task(task.id)
|
await self._store.delete_scheduled_task(task.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _log_task_usage(
|
||||||
|
state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log a line describing the state and usage of a task.
|
||||||
|
The log line is inspired by / a copy of the request log line format,
|
||||||
|
but with irrelevant fields removed.
|
||||||
|
|
||||||
|
active_time: Time that the task has been running for, in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||||
|
" [%d dbevts] %r, %r",
|
||||||
|
state,
|
||||||
|
active_time,
|
||||||
|
usage.ru_utime,
|
||||||
|
usage.ru_stime,
|
||||||
|
usage.db_sched_duration_sec,
|
||||||
|
usage.db_txn_duration_sec,
|
||||||
|
int(usage.db_txn_count),
|
||||||
|
usage.evt_db_fetch_count,
|
||||||
|
task.resource_id,
|
||||||
|
task.params,
|
||||||
|
)
|
||||||
|
|
||||||
async def _launch_task(self, task: ScheduledTask) -> None:
|
async def _launch_task(self, task: ScheduledTask) -> None:
|
||||||
"""Launch a scheduled task now.
|
"""Launch a scheduled task now.
|
||||||
|
|
||||||
|
@ -360,8 +394,32 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
function = self._actions[task.action]
|
function = self._actions[task.action]
|
||||||
|
|
||||||
|
def _occasional_report(
|
||||||
|
task_log_context: LoggingContext, start_time: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper to log a 'Task continuing' line every so often.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_time = self._clock.time()
|
||||||
|
calling_context = set_current_context(task_log_context)
|
||||||
|
try:
|
||||||
|
usage = task_log_context.get_resource_usage()
|
||||||
|
TaskScheduler._log_task_usage(
|
||||||
|
"continuing", task, usage, current_time - start_time
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_current_context(calling_context)
|
||||||
|
|
||||||
async def wrapper() -> None:
|
async def wrapper() -> None:
|
||||||
with nested_logging_context(task.id):
|
with nested_logging_context(task.id) as log_context:
|
||||||
|
start_time = self._clock.time()
|
||||||
|
occasional_status_call = self._clock.looping_call(
|
||||||
|
_occasional_report,
|
||||||
|
TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
|
||||||
|
log_context,
|
||||||
|
start_time,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
(status, result, error) = await function(task)
|
(status, result, error) = await function(task)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -383,6 +441,13 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
self._running_tasks.remove(task.id)
|
self._running_tasks.remove(task.id)
|
||||||
|
|
||||||
|
current_time = self._clock.time()
|
||||||
|
usage = log_context.get_resource_usage()
|
||||||
|
TaskScheduler._log_task_usage(
|
||||||
|
status.value, task, usage, current_time - start_time
|
||||||
|
)
|
||||||
|
occasional_status_call.stop()
|
||||||
|
|
||||||
# Try launch a new task since we've finished with this one.
|
# Try launch a new task since we've finished with this one.
|
||||||
self._clock.call_later(0.1, self._launch_scheduled_tasks)
|
self._clock.call_later(0.1, self._launch_scheduled_tasks)
|
||||||
|
|
||||||
|
|
|
@ -407,3 +407,24 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.get_success(self.store.did_forget(self.alice, self.room_id))
|
self.get_success(self.store.did_forget(self.alice, self.room_id))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_deduplicate_joins(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that calling /join multiple times does not store a new state group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||||
|
|
||||||
|
sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?"
|
||||||
|
rows = self.get_success(
|
||||||
|
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||||
|
)
|
||||||
|
initial_count = rows[0][0]
|
||||||
|
|
||||||
|
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||||
|
rows = self.get_success(
|
||||||
|
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||||
|
)
|
||||||
|
new_count = rows[0][0]
|
||||||
|
|
||||||
|
self.assertEqual(initial_count, new_count)
|
||||||
|
|
|
@ -32,7 +32,7 @@ from twisted.web.resource import Resource
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
from synapse.handlers.typing import TypingWriterHandler
|
from synapse.handlers.typing import FORGET_TIMEOUT, TypingWriterHandler
|
||||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
|
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
|
||||||
|
@ -501,3 +501,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_prune_typing_replication(self) -> None:
|
||||||
|
"""Regression test for `get_all_typing_updates` breaking when we prune
|
||||||
|
old updates
|
||||||
|
"""
|
||||||
|
self.room_members = [U_APPLE, U_BANANA]
|
||||||
|
|
||||||
|
instance_name = self.hs.get_instance_name()
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.handler.started_typing(
|
||||||
|
target_user=U_APPLE,
|
||||||
|
requester=create_requester(U_APPLE),
|
||||||
|
room_id=ROOM_ID,
|
||||||
|
timeout=10000,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows, _, _ = self.get_success(
|
||||||
|
self.handler.get_all_typing_updates(
|
||||||
|
instance_name=instance_name,
|
||||||
|
last_id=0,
|
||||||
|
current_id=self.handler.get_current_token(),
|
||||||
|
limit=100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(rows, [(1, [ROOM_ID, [U_APPLE.to_string()]])])
|
||||||
|
|
||||||
|
self.reactor.advance(20000)
|
||||||
|
|
||||||
|
rows, _, _ = self.get_success(
|
||||||
|
self.handler.get_all_typing_updates(
|
||||||
|
instance_name=instance_name,
|
||||||
|
last_id=1,
|
||||||
|
current_id=self.handler.get_current_token(),
|
||||||
|
limit=100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(rows, [(2, [ROOM_ID, []])])
|
||||||
|
|
||||||
|
self.reactor.advance(FORGET_TIMEOUT)
|
||||||
|
|
||||||
|
rows, _, _ = self.get_success(
|
||||||
|
self.handler.get_all_typing_updates(
|
||||||
|
instance_name=instance_name,
|
||||||
|
last_id=1,
|
||||||
|
current_id=self.handler.get_current_token(),
|
||||||
|
limit=100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(rows, [])
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
import itertools
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
@ -46,11 +47,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
|
||||||
from synapse.media.filepath import MediaFilePaths
|
from synapse.media.filepath import MediaFilePaths
|
||||||
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
|
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
|
||||||
from synapse.media.storage_provider import FileStorageProviderBackend
|
from synapse.media.storage_provider import FileStorageProviderBackend
|
||||||
|
from synapse.media.thumbnailer import ThumbnailProvider
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
|
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login
|
from synapse.rest.client import login, media
|
||||||
from synapse.rest.media.thumbnail_resource import ThumbnailResource
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, RoomAlias
|
from synapse.types import JsonDict, RoomAlias
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -153,68 +154,54 @@ class _TestImage:
|
||||||
is_inline: bool = True
|
is_inline: bool = True
|
||||||
|
|
||||||
|
|
||||||
@parameterized_class(
|
small_png = _TestImage(
|
||||||
("test_image",),
|
SMALL_PNG,
|
||||||
[
|
b"image/png",
|
||||||
# small png
|
b".png",
|
||||||
(
|
unhexlify(
|
||||||
_TestImage(
|
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
|
||||||
SMALL_PNG,
|
b"000000737a7af40000001a49444154789cedc101010000008220"
|
||||||
b"image/png",
|
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
|
||||||
b".png",
|
b"44ae426082"
|
||||||
unhexlify(
|
),
|
||||||
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
|
unhexlify(
|
||||||
b"000000737a7af40000001a49444154789cedc101010000008220"
|
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||||
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
|
b"0000001f15c4890000000d49444154789c636060606000000005"
|
||||||
b"44ae426082"
|
b"0001a5f645400000000049454e44ae426082"
|
||||||
),
|
),
|
||||||
unhexlify(
|
)
|
||||||
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
|
||||||
b"0000001f15c4890000000d49444154789c636060606000000005"
|
small_png_with_transparency = _TestImage(
|
||||||
b"0001a5f645400000000049454e44ae426082"
|
unhexlify(
|
||||||
),
|
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
|
||||||
),
|
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
|
||||||
),
|
b"4154789c636800000082008177cd72b60000000049454e44ae426"
|
||||||
# small png with transparency.
|
b"082"
|
||||||
(
|
),
|
||||||
_TestImage(
|
b"image/png",
|
||||||
unhexlify(
|
b".png",
|
||||||
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
|
# Note that we don't check the output since it varies across
|
||||||
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
|
# different versions of Pillow.
|
||||||
b"4154789c636800000082008177cd72b60000000049454e44ae426"
|
)
|
||||||
b"082"
|
|
||||||
),
|
small_lossless_webp = _TestImage(
|
||||||
b"image/png",
|
unhexlify(
|
||||||
b".png",
|
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
|
||||||
# Note that we don't check the output since it varies across
|
),
|
||||||
# different versions of Pillow.
|
b"image/webp",
|
||||||
),
|
b".webp",
|
||||||
),
|
)
|
||||||
# small lossless webp
|
|
||||||
(
|
empty_file = _TestImage(
|
||||||
_TestImage(
|
b"",
|
||||||
unhexlify(
|
b"image/gif",
|
||||||
b"524946461a000000574542505650384c0d0000002f0000001007"
|
b".gif",
|
||||||
b"1011118888fe0700"
|
expected_found=False,
|
||||||
),
|
unable_to_thumbnail=True,
|
||||||
b"image/webp",
|
)
|
||||||
b".webp",
|
|
||||||
),
|
SVG = _TestImage(
|
||||||
),
|
b"""<?xml version="1.0"?>
|
||||||
# an empty file
|
|
||||||
(
|
|
||||||
_TestImage(
|
|
||||||
b"",
|
|
||||||
b"image/gif",
|
|
||||||
b".gif",
|
|
||||||
expected_found=False,
|
|
||||||
unable_to_thumbnail=True,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
# An SVG.
|
|
||||||
(
|
|
||||||
_TestImage(
|
|
||||||
b"""<?xml version="1.0"?>
|
|
||||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||||
|
|
||||||
|
@ -223,19 +210,32 @@ class _TestImage:
|
||||||
<circle cx="100" cy="100" r="50" stroke="black"
|
<circle cx="100" cy="100" r="50" stroke="black"
|
||||||
stroke-width="5" fill="red" />
|
stroke-width="5" fill="red" />
|
||||||
</svg>""",
|
</svg>""",
|
||||||
b"image/svg",
|
b"image/svg",
|
||||||
b".svg",
|
b".svg",
|
||||||
expected_found=False,
|
expected_found=False,
|
||||||
unable_to_thumbnail=True,
|
unable_to_thumbnail=True,
|
||||||
is_inline=False,
|
is_inline=False,
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
test_images = [
|
||||||
|
small_png,
|
||||||
|
small_png_with_transparency,
|
||||||
|
small_lossless_webp,
|
||||||
|
empty_file,
|
||||||
|
SVG,
|
||||||
|
]
|
||||||
|
urls = [
|
||||||
|
"_matrix/media/r0/thumbnail",
|
||||||
|
"_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
|
||||||
class MediaRepoTests(unittest.HomeserverTestCase):
|
class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
servlets = [media.register_servlets]
|
||||||
test_image: ClassVar[_TestImage]
|
test_image: ClassVar[_TestImage]
|
||||||
hijack_auth = True
|
hijack_auth = True
|
||||||
user_id = "@test:user"
|
user_id = "@test:user"
|
||||||
|
url: ClassVar[str]
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.fetches: List[
|
self.fetches: List[
|
||||||
|
@ -298,6 +298,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
"config": {"directory": self.storage_path},
|
"config": {"directory": self.storage_path},
|
||||||
}
|
}
|
||||||
config["media_storage_providers"] = [provider_config]
|
config["media_storage_providers"] = [provider_config]
|
||||||
|
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
|
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
|
||||||
|
|
||||||
|
@ -502,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
params = "?width=32&height=32&method=scale"
|
params = "?width=32&height=32&method=scale"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
|
f"/{self.url}/{self.media_id}{params}",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
await_result=False,
|
await_result=False,
|
||||||
)
|
)
|
||||||
|
@ -530,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
|
f"/{self.url}/{self.media_id}{params}",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
await_result=False,
|
await_result=False,
|
||||||
)
|
)
|
||||||
|
@ -566,12 +567,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
params = "?width=32&height=32&method=" + method
|
params = "?width=32&height=32&method=" + method
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
|
f"/{self.url}/{self.media_id}{params}",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
await_result=False,
|
await_result=False,
|
||||||
)
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||||
b"Content-Type": [self.test_image.content_type],
|
b"Content-Type": [self.test_image.content_type],
|
||||||
|
@ -580,7 +580,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
(self.test_image.data, (len(self.test_image.data), headers))
|
(self.test_image.data, (len(self.test_image.data), headers))
|
||||||
)
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
if expected_found:
|
if expected_found:
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
@ -603,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
channel.json_body,
|
channel.json_body,
|
||||||
{
|
{
|
||||||
"errcode": "M_UNKNOWN",
|
"errcode": "M_UNKNOWN",
|
||||||
"error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
|
"error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -613,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
channel.json_body,
|
channel.json_body,
|
||||||
{
|
{
|
||||||
"errcode": "M_NOT_FOUND",
|
"errcode": "M_NOT_FOUND",
|
||||||
"error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
|
"error": f"Not found '/{self.url}/example.com/12345'",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -625,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
content_type = self.test_image.content_type.decode()
|
content_type = self.test_image.content_type.decode()
|
||||||
media_repo = self.hs.get_media_repository()
|
media_repo = self.hs.get_media_repository()
|
||||||
thumbnail_resouce = ThumbnailResource(
|
thumbnail_provider = ThumbnailProvider(
|
||||||
self.hs, media_repo, media_repo.media_storage
|
self.hs, media_repo, media_repo.media_storage
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(
|
self.assertIsNotNone(
|
||||||
thumbnail_resouce._select_thumbnail(
|
thumbnail_provider._select_thumbnail(
|
||||||
desired_width=desired_size,
|
desired_width=desired_size,
|
||||||
desired_height=desired_size,
|
desired_height=desired_size,
|
||||||
desired_method=method,
|
desired_method=method,
|
||||||
|
|
|
@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import NotFoundError
|
from synapse.api.errors import NotFoundError
|
||||||
from synapse.rest import admin, devices, room, sync
|
from synapse.rest import admin, devices, sync
|
||||||
from synapse.rest.client import account, keys, login, register
|
from synapse.rest.client import keys, login, register
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, UserID, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -33,146 +33,6 @@ from synapse.util import Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class DeviceListsTestCase(unittest.HomeserverTestCase):
|
|
||||||
"""Tests regarding device list changes."""
|
|
||||||
|
|
||||||
servlets = [
|
|
||||||
admin.register_servlets_for_client_rest_resource,
|
|
||||||
login.register_servlets,
|
|
||||||
register.register_servlets,
|
|
||||||
account.register_servlets,
|
|
||||||
room.register_servlets,
|
|
||||||
sync.register_servlets,
|
|
||||||
devices.register_servlets,
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_receiving_local_device_list_changes(self) -> None:
|
|
||||||
"""Tests that a local users that share a room receive each other's device list
|
|
||||||
changes.
|
|
||||||
"""
|
|
||||||
# Register two users
|
|
||||||
test_device_id = "TESTDEVICE"
|
|
||||||
alice_user_id = self.register_user("alice", "correcthorse")
|
|
||||||
alice_access_token = self.login(
|
|
||||||
alice_user_id, "correcthorse", device_id=test_device_id
|
|
||||||
)
|
|
||||||
|
|
||||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
|
||||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
|
||||||
|
|
||||||
# Create a room for them to coexist peacefully in
|
|
||||||
new_room_id = self.helper.create_room_as(
|
|
||||||
alice_user_id, is_public=True, tok=alice_access_token
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(new_room_id)
|
|
||||||
|
|
||||||
# Have Bob join the room
|
|
||||||
self.helper.invite(
|
|
||||||
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
|
|
||||||
)
|
|
||||||
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
|
|
||||||
|
|
||||||
# Now have Bob initiate an initial sync (in order to get a since token)
|
|
||||||
channel = self.make_request(
|
|
||||||
"GET",
|
|
||||||
"/sync",
|
|
||||||
access_token=bob_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
|
||||||
next_batch_token = channel.json_body["next_batch"]
|
|
||||||
|
|
||||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
|
||||||
# which we hope will happen as a result of Alice updating their device list.
|
|
||||||
bob_sync_channel = self.make_request(
|
|
||||||
"GET",
|
|
||||||
f"/sync?since={next_batch_token}&timeout=30000",
|
|
||||||
access_token=bob_access_token,
|
|
||||||
# Start the request, then continue on.
|
|
||||||
await_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Have alice update their device list
|
|
||||||
channel = self.make_request(
|
|
||||||
"PUT",
|
|
||||||
f"/devices/{test_device_id}",
|
|
||||||
{
|
|
||||||
"display_name": "New Device Name",
|
|
||||||
},
|
|
||||||
access_token=alice_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
|
||||||
|
|
||||||
# Check that bob's incremental sync contains the updated device list.
|
|
||||||
# If not, the client would only receive the device list update on the
|
|
||||||
# *next* sync.
|
|
||||||
bob_sync_channel.await_result()
|
|
||||||
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
|
||||||
|
|
||||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
|
||||||
"changed", []
|
|
||||||
)
|
|
||||||
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
|
|
||||||
|
|
||||||
def test_not_receiving_local_device_list_changes(self) -> None:
|
|
||||||
"""Tests a local users DO NOT receive device updates from each other if they do not
|
|
||||||
share a room.
|
|
||||||
"""
|
|
||||||
# Register two users
|
|
||||||
test_device_id = "TESTDEVICE"
|
|
||||||
alice_user_id = self.register_user("alice", "correcthorse")
|
|
||||||
alice_access_token = self.login(
|
|
||||||
alice_user_id, "correcthorse", device_id=test_device_id
|
|
||||||
)
|
|
||||||
|
|
||||||
bob_user_id = self.register_user("bob", "ponyponypony")
|
|
||||||
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
|
||||||
|
|
||||||
# These users do not share a room. They are lonely.
|
|
||||||
|
|
||||||
# Have Bob initiate an initial sync (in order to get a since token)
|
|
||||||
channel = self.make_request(
|
|
||||||
"GET",
|
|
||||||
"/sync",
|
|
||||||
access_token=bob_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
|
||||||
next_batch_token = channel.json_body["next_batch"]
|
|
||||||
|
|
||||||
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
|
||||||
# which we hope will happen as a result of Alice updating their device list.
|
|
||||||
bob_sync_channel = self.make_request(
|
|
||||||
"GET",
|
|
||||||
f"/sync?since={next_batch_token}&timeout=1000",
|
|
||||||
access_token=bob_access_token,
|
|
||||||
# Start the request, then continue on.
|
|
||||||
await_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Have alice update their device list
|
|
||||||
channel = self.make_request(
|
|
||||||
"PUT",
|
|
||||||
f"/devices/{test_device_id}",
|
|
||||||
{
|
|
||||||
"display_name": "New Device Name",
|
|
||||||
},
|
|
||||||
access_token=alice_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
|
||||||
|
|
||||||
# Check that bob's incremental sync does not contain the updated device list.
|
|
||||||
bob_sync_channel.await_result()
|
|
||||||
self.assertEqual(
|
|
||||||
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
|
|
||||||
)
|
|
||||||
|
|
||||||
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
|
||||||
"changed", []
|
|
||||||
)
|
|
||||||
self.assertNotIn(
|
|
||||||
alice_user_id, changed_device_lists, bob_sync_channel.json_body
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DevicesTestCase(unittest.HomeserverTestCase):
|
class DevicesTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
|
|
1609
tests/rest/client/test_media.py
Normal file
1609
tests/rest/client/test_media.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -18,15 +18,39 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
from parameterized import parameterized_class
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, sendtodevice, sync
|
from synapse.rest.client import login, sendtodevice, sync
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(
|
||||||
|
("sync_endpoint", "experimental_features"),
|
||||||
|
[
|
||||||
|
("/sync", {}),
|
||||||
|
(
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||||
|
# Enable sliding sync
|
||||||
|
{"msc3575_enabled": True},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
class SendToDeviceTestCase(HomeserverTestCase):
|
class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sync_endpoint: The endpoint under test to use for syncing.
|
||||||
|
experimental_features: The experimental features homeserver config to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sync_endpoint: str
|
||||||
|
experimental_features: JsonDict
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = self.experimental_features
|
||||||
|
return config
|
||||||
|
|
||||||
def test_user_to_user(self) -> None:
|
def test_user_to_user(self) -> None:
|
||||||
"""A to-device message from one user to another should get delivered"""
|
"""A to-device message from one user to another should get delivered"""
|
||||||
|
|
||||||
|
@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
# check it appears
|
# check it appears
|
||||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
expected_result = {
|
expected_result = {
|
||||||
"events": [
|
"events": [
|
||||||
|
@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.assertEqual(channel.json_body["to_device"], expected_result)
|
self.assertEqual(channel.json_body["to_device"], expected_result)
|
||||||
|
|
||||||
# it should re-appear if we do another sync
|
# it should re-appear if we do another sync because the to-device message is not
|
||||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
# deleted until we acknowledge it by sending a `?since=...` parameter in the
|
||||||
|
# next sync request corresponding to the `next_batch` value from the response.
|
||||||
|
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
self.assertEqual(channel.json_body["to_device"], expected_result)
|
self.assertEqual(channel.json_body["to_device"], expected_result)
|
||||||
|
|
||||||
# it should *not* appear if we do an incremental sync
|
# it should *not* appear if we do an incremental sync
|
||||||
sync_token = channel.json_body["next_batch"]
|
sync_token = channel.json_body["next_batch"]
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={sync_token}",
|
||||||
|
access_token=user2_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
|
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
|
||||||
|
@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
# now sync: we should get two of the three
|
# now sync: we should get two of the three (because burst_count=2)
|
||||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
msgs = channel.json_body["to_device"]["events"]
|
msgs = channel.json_body["to_device"]["events"]
|
||||||
self.assertEqual(len(msgs), 2)
|
self.assertEqual(len(msgs), 2)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
msgs[i],
|
msgs[i],
|
||||||
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
|
{
|
||||||
|
"sender": user1,
|
||||||
|
"type": "m.room_key_request",
|
||||||
|
"content": {"idx": i},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
sync_token = channel.json_body["next_batch"]
|
sync_token = channel.json_body["next_batch"]
|
||||||
|
|
||||||
|
@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# ... which should arrive
|
# ... which should arrive
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={sync_token}",
|
||||||
|
access_token=user2_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
msgs = channel.json_body["to_device"]["events"]
|
msgs = channel.json_body["to_device"]["events"]
|
||||||
|
@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# now sync: we should get two of the three
|
# now sync: we should get two of the three
|
||||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
msgs = channel.json_body["to_device"]["events"]
|
msgs = channel.json_body["to_device"]["events"]
|
||||||
self.assertEqual(len(msgs), 2)
|
self.assertEqual(len(msgs), 2)
|
||||||
|
@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# ... which should arrive
|
# ... which should arrive
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?since={sync_token}", access_token=user2_tok
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={sync_token}",
|
||||||
|
access_token=user2_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
msgs = channel.json_body["to_device"]["events"]
|
msgs = channel.json_body["to_device"]["events"]
|
||||||
|
@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
user2_tok = self.login("u2", "pass", "d2")
|
user2_tok = self.login("u2", "pass", "d2")
|
||||||
|
|
||||||
# Do an initial sync
|
# Do an initial sync
|
||||||
channel = self.make_request("GET", "/sync", access_token=user2_tok)
|
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
sync_token = channel.json_body["next_batch"]
|
sync_token = channel.json_body["next_batch"]
|
||||||
|
|
||||||
|
@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
|
||||||
|
access_token=user2_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
messages = channel.json_body.get("to_device", {}).get("events", [])
|
messages = channel.json_body.get("to_device", {}).get("events", [])
|
||||||
|
@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
|
||||||
sync_token = channel.json_body["next_batch"]
|
sync_token = channel.json_body["next_batch"]
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
|
||||||
|
access_token=user2_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
messages = channel.json_body.get("to_device", {}).get("events", [])
|
messages = channel.json_body.get("to_device", {}).get("events", [])
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized, parameterized_class
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(
|
||||||
|
("sync_endpoint", "experimental_features"),
|
||||||
|
[
|
||||||
|
("/sync", {}),
|
||||||
|
(
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||||
|
# Enable sliding sync
|
||||||
|
{"msc3575_enabled": True},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Tests regarding device list (`device_lists`) changes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sync_endpoint: The endpoint under test to use for syncing.
|
||||||
|
experimental_features: The experimental features homeserver config to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sync_endpoint: str
|
||||||
|
experimental_features: JsonDict
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
synapse.rest.admin.register_servlets,
|
synapse.rest.admin.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
devices.register_servlets,
|
devices.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = self.experimental_features
|
||||||
|
return config
|
||||||
|
|
||||||
|
def test_receiving_local_device_list_changes(self) -> None:
|
||||||
|
"""Tests that a local users that share a room receive each other's device list
|
||||||
|
changes.
|
||||||
|
"""
|
||||||
|
# Register two users
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||||
|
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||||
|
|
||||||
|
# Create a room for them to coexist peacefully in
|
||||||
|
new_room_id = self.helper.create_room_as(
|
||||||
|
alice_user_id, is_public=True, tok=alice_access_token
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(new_room_id)
|
||||||
|
|
||||||
|
# Have Bob join the room
|
||||||
|
self.helper.invite(
|
||||||
|
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
|
||||||
|
)
|
||||||
|
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
|
||||||
|
|
||||||
|
# Now have Bob initiate an initial sync (in order to get a since token)
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.sync_endpoint,
|
||||||
|
access_token=bob_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
next_batch_token = channel.json_body["next_batch"]
|
||||||
|
|
||||||
|
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||||
|
# which we hope will happen as a result of Alice updating their device list.
|
||||||
|
bob_sync_channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
|
||||||
|
access_token=bob_access_token,
|
||||||
|
# Start the request, then continue on.
|
||||||
|
await_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Have alice update their device list
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"/devices/{test_device_id}",
|
||||||
|
{
|
||||||
|
"display_name": "New Device Name",
|
||||||
|
},
|
||||||
|
access_token=alice_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check that bob's incremental sync contains the updated device list.
|
||||||
|
# If not, the client would only receive the device list update on the
|
||||||
|
# *next* sync.
|
||||||
|
bob_sync_channel.await_result()
|
||||||
|
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
||||||
|
|
||||||
|
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||||
|
"changed", []
|
||||||
|
)
|
||||||
|
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
|
||||||
|
|
||||||
|
def test_not_receiving_local_device_list_changes(self) -> None:
|
||||||
|
"""Tests a local users DO NOT receive device updates from each other if they do not
|
||||||
|
share a room.
|
||||||
|
"""
|
||||||
|
# Register two users
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
bob_user_id = self.register_user("bob", "ponyponypony")
|
||||||
|
bob_access_token = self.login(bob_user_id, "ponyponypony")
|
||||||
|
|
||||||
|
# These users do not share a room. They are lonely.
|
||||||
|
|
||||||
|
# Have Bob initiate an initial sync (in order to get a since token)
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.sync_endpoint,
|
||||||
|
access_token=bob_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
next_batch_token = channel.json_body["next_batch"]
|
||||||
|
|
||||||
|
# ...and then an incremental sync. This should block until the sync stream is woken up,
|
||||||
|
# which we hope will happen as a result of Alice updating their device list.
|
||||||
|
bob_sync_channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
|
||||||
|
access_token=bob_access_token,
|
||||||
|
# Start the request, then continue on.
|
||||||
|
await_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Have alice update their device list
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"/devices/{test_device_id}",
|
||||||
|
{
|
||||||
|
"display_name": "New Device Name",
|
||||||
|
},
|
||||||
|
access_token=alice_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check that bob's incremental sync does not contain the updated device list.
|
||||||
|
bob_sync_channel.await_result()
|
||||||
|
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
|
||||||
|
|
||||||
|
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
|
||||||
|
"changed", []
|
||||||
|
)
|
||||||
|
self.assertNotIn(
|
||||||
|
alice_user_id, changed_device_lists, bob_sync_channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
|
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
|
||||||
"""Tests that a user with no rooms still receives their own device list updates"""
|
"""Tests that a user with no rooms still receives their own device list updates"""
|
||||||
device_id = "TESTDEVICE"
|
test_device_id = "TESTDEVICE"
|
||||||
|
|
||||||
# Register a user and login, creating a device
|
# Register a user and login, creating a device
|
||||||
self.user_id = self.register_user("kermit", "monkey")
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
self.tok = self.login("kermit", "monkey", device_id=device_id)
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
# Request an initial sync
|
# Request an initial sync
|
||||||
channel = self.make_request("GET", "/sync", access_token=self.tok)
|
channel = self.make_request(
|
||||||
|
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||||
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
next_batch = channel.json_body["next_batch"]
|
next_batch = channel.json_body["next_batch"]
|
||||||
|
|
||||||
|
@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
||||||
# It won't return until something has happened
|
# It won't return until something has happened
|
||||||
incremental_sync_channel = self.make_request(
|
incremental_sync_channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/sync?since={next_batch}&timeout=30000",
|
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
|
||||||
access_token=self.tok,
|
access_token=alice_access_token,
|
||||||
await_result=False,
|
await_result=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Change our device's display name
|
# Change our device's display name
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"devices/{device_id}",
|
f"devices/{test_device_id}",
|
||||||
{
|
{
|
||||||
"display_name": "freeze ray",
|
"display_name": "freeze ray",
|
||||||
},
|
},
|
||||||
access_token=self.tok,
|
access_token=alice_access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
||||||
).get("changed", [])
|
).get("changed", [])
|
||||||
|
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
self.user_id, device_list_changes, incremental_sync_channel.json_body
|
alice_user_id, device_list_changes, incremental_sync_channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(
|
||||||
|
("sync_endpoint", "experimental_features"),
|
||||||
|
[
|
||||||
|
("/sync", {}),
|
||||||
|
(
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||||
|
# Enable sliding sync
|
||||||
|
{"msc3575_enabled": True},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Tests regarding device one time keys (`device_one_time_keys_count`) changes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sync_endpoint: The endpoint under test to use for syncing.
|
||||||
|
experimental_features: The experimental features homeserver config to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sync_endpoint: str
|
||||||
|
experimental_features: JsonDict
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
devices.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = self.experimental_features
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
|
def test_no_device_one_time_keys(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests when no one time keys set, it still has the default `signed_curve25519` in
|
||||||
|
`device_one_time_keys_count`
|
||||||
|
"""
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request an initial sync
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check for those one time key counts
|
||||||
|
self.assertDictEqual(
|
||||||
|
channel.json_body["device_one_time_keys_count"],
|
||||||
|
# Note that "signed_curve25519" is always returned in key count responses
|
||||||
|
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||||
|
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||||
|
{"signed_curve25519": 0},
|
||||||
|
channel.json_body["device_one_time_keys_count"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_returns_device_one_time_keys(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that one time keys for the device/user are counted correctly in the `/sync`
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upload one time keys for the user/device
|
||||||
|
keys: JsonDict = {
|
||||||
|
"alg1:k1": "key1",
|
||||||
|
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||||
|
"alg2:k3": {"key": "key3"},
|
||||||
|
}
|
||||||
|
res = self.get_success(
|
||||||
|
self.e2e_keys_handler.upload_keys_for_user(
|
||||||
|
alice_user_id, test_device_id, {"one_time_keys": keys}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Note that "signed_curve25519" is always returned in key count responses
|
||||||
|
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||||
|
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||||
|
self.assertDictEqual(
|
||||||
|
res,
|
||||||
|
{"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request an initial sync
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check for those one time key counts
|
||||||
|
self.assertDictEqual(
|
||||||
|
channel.json_body["device_one_time_keys_count"],
|
||||||
|
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
|
||||||
|
channel.json_body["device_one_time_keys_count"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(
|
||||||
|
("sync_endpoint", "experimental_features"),
|
||||||
|
[
|
||||||
|
("/sync", {}),
|
||||||
|
(
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
|
||||||
|
# Enable sliding sync
|
||||||
|
{"msc3575_enabled": True},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sync_endpoint: The endpoint under test to use for syncing.
|
||||||
|
experimental_features: The experimental features homeserver config to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sync_endpoint: str
|
||||||
|
experimental_features: JsonDict
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
devices.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = self.experimental_features
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.store = self.hs.get_datastores().main
|
||||||
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
|
def test_no_device_unused_fallback_key(self) -> None:
|
||||||
|
"""
|
||||||
|
Test when no unused fallback key is set, it just returns an empty list. The MSC
|
||||||
|
says "The device_unused_fallback_key_types parameter must be present if the
|
||||||
|
server supports fallback keys.",
|
||||||
|
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
|
||||||
|
"""
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request an initial sync
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check for those one time key counts
|
||||||
|
self.assertListEqual(
|
||||||
|
channel.json_body["device_unused_fallback_key_types"],
|
||||||
|
[],
|
||||||
|
channel.json_body["device_unused_fallback_key_types"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_returns_device_one_time_keys(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that device unused fallback key type is returned correctly in the `/sync`
|
||||||
|
"""
|
||||||
|
test_device_id = "TESTDEVICE"
|
||||||
|
|
||||||
|
alice_user_id = self.register_user("alice", "correcthorse")
|
||||||
|
alice_access_token = self.login(
|
||||||
|
alice_user_id, "correcthorse", device_id=test_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# We shouldn't have any unused fallback keys yet
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(res, [])
|
||||||
|
|
||||||
|
# Upload a fallback key for the user/device
|
||||||
|
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||||
|
self.get_success(
|
||||||
|
self.e2e_keys_handler.upload_keys_for_user(
|
||||||
|
alice_user_id,
|
||||||
|
test_device_id,
|
||||||
|
{"fallback_keys": fallback_key},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# We should now have an unused alg1 key
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"], fallback_res)
|
||||||
|
|
||||||
|
# Request an initial sync
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET", self.sync_endpoint, access_token=alice_access_token
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Check for the unused fallback key types
|
||||||
|
self.assertListEqual(
|
||||||
|
channel.json_body["device_unused_fallback_key_types"],
|
||||||
|
["alg1"],
|
||||||
|
channel.json_body["device_unused_fallback_key_types"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
|
||||||
# from a regular 404.
|
# from a regular 404.
|
||||||
file_id = "abcdefg12345"
|
file_id = "abcdefg12345"
|
||||||
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
||||||
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
|
|
||||||
f,
|
media_storage = hs.get_media_repository().media_storage
|
||||||
fname,
|
|
||||||
finish,
|
ctx = media_storage.store_into_file(file_info)
|
||||||
):
|
(f, fname) = self.get_success(ctx.__aenter__())
|
||||||
f.write(SMALL_PNG)
|
f.write(SMALL_PNG)
|
||||||
self.get_success(finish())
|
self.get_success(ctx.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_cached_remote_media(
|
self.store.store_cached_remote_media(
|
||||||
|
|
|
@ -30,163 +30,34 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
from synapse.storage.util.sequence import (
|
||||||
|
LocalSequenceGenerator,
|
||||||
|
PostgresSequenceGenerator,
|
||||||
|
SequenceGenerator,
|
||||||
|
)
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGeneratorTestCase(HomeserverTestCase):
|
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||||
|
|
||||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
txn.execute(
|
self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
|
||||||
"""
|
else:
|
||||||
CREATE TABLE foobar (
|
self.seq_gen = LocalSequenceGenerator(lambda _: 0)
|
||||||
stream_id BIGINT NOT NULL,
|
|
||||||
data TEXT
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
|
|
||||||
|
|
||||||
def _create_id_generator(self) -> StreamIdGenerator:
|
|
||||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
|
||||||
return StreamIdGenerator(
|
|
||||||
db_conn=conn,
|
|
||||||
notifier=self.hs.get_replication_notifier(),
|
|
||||||
table="foobar",
|
|
||||||
column="stream_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
|
||||||
|
|
||||||
def test_initial_value(self) -> None:
|
|
||||||
"""Check that we read the current token from the DB."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
def test_single_gen_next(self) -> None:
|
|
||||||
"""Check that we correctly increment the current token from the DB."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
async with id_gen.get_next() as next_id:
|
|
||||||
# We haven't persisted `next_id` yet; current token is still 123
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
# But we did learn what the next value is
|
|
||||||
self.assertEqual(next_id, 124)
|
|
||||||
|
|
||||||
# Once the context manager closes we assume that the `next_id` has been
|
|
||||||
# written to the DB.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_multiple_gen_nexts(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next sensibly."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request three new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
# None are persisted: current token unchanged.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Persist each in turn.
|
|
||||||
await ctx1.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 125)
|
|
||||||
await ctx3.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next, even when their IDs
|
|
||||||
created and persisted in different orders."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request three new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
# None are persisted: current token unchanged.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Persist them in a different order, starting with 126 from ctx3.
|
|
||||||
await ctx3.__aexit__(None, None, None)
|
|
||||||
# We haven't persisted 124 from ctx1 yet---current token is still 123.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Now persist 124 from ctx1.
|
|
||||||
await ctx1.__aexit__(None, None, None)
|
|
||||||
# Current token is then 124, waiting for 125 to be persisted.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
|
|
||||||
# Finally persist 125 from ctx2.
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
# Current token is then 126 (skipping over 125).
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request two new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
|
|
||||||
# Persist ctx2 first.
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
# Still waiting on ctx1's ID to be persisted.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Now request a third stream ID. It should be 126 (the smallest ID that
|
|
||||||
# we've not yet handed out.)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
|
||||||
skip = "Requires Postgres"
|
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
|
||||||
self.store = hs.get_datastores().main
|
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
|
||||||
|
|
||||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
|
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE foobar (
|
CREATE TABLE foobar (
|
||||||
|
@ -221,44 +92,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def _insert(txn: LoggingTransaction) -> None:
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
for _ in range(number):
|
for _ in range(number):
|
||||||
|
next_val = self.seq_gen.get_next_id_txn(txn)
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
|
||||||
(instance_name,),
|
(
|
||||||
|
next_val,
|
||||||
|
instance_name,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
|
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
|
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||||
""",
|
""",
|
||||||
(instance_name,),
|
(instance_name, next_val, next_val),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||||
|
|
||||||
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
|
|
||||||
"""Insert one row as the given instance with given stream_id, updating
|
|
||||||
the postgres sequence position to match.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _insert(txn: LoggingTransaction) -> None:
|
|
||||||
txn.execute(
|
|
||||||
"INSERT INTO foobar VALUES (?, ?)",
|
|
||||||
(
|
|
||||||
stream_id,
|
|
||||||
instance_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
|
||||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
|
||||||
""",
|
|
||||||
(instance_name, stream_id, stream_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
|
||||||
|
|
||||||
|
class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
def test_empty(self) -> None:
|
def test_empty(self) -> None:
|
||||||
"""Test an ID generator against an empty database gives sensible
|
"""Test an ID generator against an empty database gives sensible
|
||||||
current positions.
|
current positions.
|
||||||
|
@ -347,6 +201,176 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||||
|
|
||||||
|
def test_get_next_txn(self) -> None:
|
||||||
|
"""Test that the `get_next_txn` function works correctly."""
|
||||||
|
|
||||||
|
# Prefill table with 7 rows written by 'master'
|
||||||
|
self._insert_rows("master", 7)
|
||||||
|
|
||||||
|
id_gen = self._create_id_generator()
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
|
# Try allocating a new ID gen and check that we only see position
|
||||||
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
|
def _get_next_txn(txn: LoggingTransaction) -> None:
|
||||||
|
stream_id = id_gen.get_next_txn(txn)
|
||||||
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
|
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||||
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||||
|
|
||||||
|
def test_restart_during_out_of_order_persistence(self) -> None:
|
||||||
|
"""Test that restarting a process while another process is writing out
|
||||||
|
of order updates are handled correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Prefill table with 7 rows written by 'master'
|
||||||
|
self._insert_rows("master", 7)
|
||||||
|
|
||||||
|
id_gen = self._create_id_generator()
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
|
# Persist two rows at once
|
||||||
|
ctx1 = id_gen.get_next()
|
||||||
|
ctx2 = id_gen.get_next()
|
||||||
|
|
||||||
|
s1 = self.get_success(ctx1.__aenter__())
|
||||||
|
s2 = self.get_success(ctx2.__aenter__())
|
||||||
|
|
||||||
|
self.assertEqual(s1, 8)
|
||||||
|
self.assertEqual(s2, 9)
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
|
# We finish persisting the second row before restart
|
||||||
|
self.get_success(ctx2.__aexit__(None, None, None))
|
||||||
|
|
||||||
|
# We simulate a restart of another worker by just creating a new ID gen.
|
||||||
|
id_gen_worker = self._create_id_generator("worker")
|
||||||
|
|
||||||
|
# Restarted worker should not see the second persisted row
|
||||||
|
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
|
||||||
|
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
|
# Now if we persist the first row then both instances should jump ahead
|
||||||
|
# correctly.
|
||||||
|
self.get_success(ctx1.__aexit__(None, None, None))
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||||
|
id_gen_worker.advance("master", 9)
|
||||||
|
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
|
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
|
||||||
|
"""Insert one row as the given instance with given stream_id, updating
|
||||||
|
the postgres sequence position to match.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(
|
||||||
|
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
|
||||||
|
(
|
||||||
|
stream_id,
|
||||||
|
instance_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||||
|
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||||
|
""",
|
||||||
|
(instance_name, stream_id, stream_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
||||||
|
|
||||||
|
def test_get_persisted_upto_position(self) -> None:
|
||||||
|
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||||
|
positions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The following tests are a bit cheeky in that we notify about new
|
||||||
|
# positions via `advance` without *actually* advancing the postgres
|
||||||
|
# sequence.
|
||||||
|
|
||||||
|
self._insert_row_with_id("first", 3)
|
||||||
|
self._insert_row_with_id("second", 5)
|
||||||
|
|
||||||
|
id_gen = self._create_id_generator("worker", writers=["first", "second"])
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||||
|
|
||||||
|
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||||
|
|
||||||
|
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||||
|
# we expect it to be 6
|
||||||
|
id_gen.advance("first", 6)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||||
|
|
||||||
|
# No gap, so we expect 7.
|
||||||
|
id_gen.advance("second", 7)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
|
# We haven't seen 8 yet, so we expect 7 still.
|
||||||
|
id_gen.advance("second", 9)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
|
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||||
|
id_gen.advance("first", 8)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||||
|
|
||||||
|
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||||
|
# 10 we know that everything before 11 must be persisted.
|
||||||
|
id_gen.advance("first", 11)
|
||||||
|
id_gen.advance("second", 15)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||||
|
|
||||||
|
def test_get_persisted_upto_position_get_next(self) -> None:
|
||||||
|
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||||
|
positions when `get_next` is called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._insert_row_with_id("first", 3)
|
||||||
|
self._insert_row_with_id("second", 5)
|
||||||
|
|
||||||
|
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
|
||||||
|
async def _get_next_async() -> None:
|
||||||
|
async with id_gen.get_next() as stream_id:
|
||||||
|
self.assertEqual(stream_id, 6)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||||
|
|
||||||
|
# We assume that so long as `get_next` does correctly advance the
|
||||||
|
# `persisted_upto_position` in this case, then it will be correct in the
|
||||||
|
# other cases that are tested above (since they'll hit the same code).
|
||||||
|
|
||||||
def test_multi_instance(self) -> None:
|
def test_multi_instance(self) -> None:
|
||||||
"""Test that reads and writes from multiple processes are handled
|
"""Test that reads and writes from multiple processes are handled
|
||||||
correctly.
|
correctly.
|
||||||
|
@ -453,145 +477,6 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
|
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_next_txn(self) -> None:
|
|
||||||
"""Test that the `get_next_txn` function works correctly."""
|
|
||||||
|
|
||||||
# Prefill table with 7 rows written by 'master'
|
|
||||||
self._insert_rows("master", 7)
|
|
||||||
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
|
||||||
|
|
||||||
# Try allocating a new ID gen and check that we only see position
|
|
||||||
# advanced after we leave the context manager.
|
|
||||||
|
|
||||||
def _get_next_txn(txn: LoggingTransaction) -> None:
|
|
||||||
stream_id = id_gen.get_next_txn(txn)
|
|
||||||
self.assertEqual(stream_id, 8)
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
|
||||||
|
|
||||||
def test_get_persisted_upto_position(self) -> None:
|
|
||||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
|
||||||
positions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# The following tests are a bit cheeky in that we notify about new
|
|
||||||
# positions via `advance` without *actually* advancing the postgres
|
|
||||||
# sequence.
|
|
||||||
|
|
||||||
self._insert_row_with_id("first", 3)
|
|
||||||
self._insert_row_with_id("second", 5)
|
|
||||||
|
|
||||||
id_gen = self._create_id_generator("worker", writers=["first", "second"])
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
|
||||||
|
|
||||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
|
||||||
|
|
||||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
|
||||||
# we expect it to be 6
|
|
||||||
id_gen.advance("first", 6)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
|
||||||
|
|
||||||
# No gap, so we expect 7.
|
|
||||||
id_gen.advance("second", 7)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
|
||||||
|
|
||||||
# We haven't seen 8 yet, so we expect 7 still.
|
|
||||||
id_gen.advance("second", 9)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
|
||||||
|
|
||||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
|
||||||
id_gen.advance("first", 8)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
|
||||||
|
|
||||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
|
||||||
# 10 we know that everything before 11 must be persisted.
|
|
||||||
id_gen.advance("first", 11)
|
|
||||||
id_gen.advance("second", 15)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
|
||||||
|
|
||||||
def test_get_persisted_upto_position_get_next(self) -> None:
|
|
||||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
|
||||||
positions when `get_next` is called.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._insert_row_with_id("first", 3)
|
|
||||||
self._insert_row_with_id("second", 5)
|
|
||||||
|
|
||||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
|
||||||
|
|
||||||
async def _get_next_async() -> None:
|
|
||||||
async with id_gen.get_next() as stream_id:
|
|
||||||
self.assertEqual(stream_id, 6)
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
|
||||||
|
|
||||||
self.get_success(_get_next_async())
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
|
||||||
|
|
||||||
# We assume that so long as `get_next` does correctly advance the
|
|
||||||
# `persisted_upto_position` in this case, then it will be correct in the
|
|
||||||
# other cases that are tested above (since they'll hit the same code).
|
|
||||||
|
|
||||||
def test_restart_during_out_of_order_persistence(self) -> None:
|
|
||||||
"""Test that restarting a process while another process is writing out
|
|
||||||
of order updates are handled correctly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Prefill table with 7 rows written by 'master'
|
|
||||||
self._insert_rows("master", 7)
|
|
||||||
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
|
||||||
|
|
||||||
# Persist two rows at once
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
|
|
||||||
s1 = self.get_success(ctx1.__aenter__())
|
|
||||||
s2 = self.get_success(ctx2.__aenter__())
|
|
||||||
|
|
||||||
self.assertEqual(s1, 8)
|
|
||||||
self.assertEqual(s2, 9)
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
|
||||||
|
|
||||||
# We finish persisting the second row before restart
|
|
||||||
self.get_success(ctx2.__aexit__(None, None, None))
|
|
||||||
|
|
||||||
# We simulate a restart of another worker by just creating a new ID gen.
|
|
||||||
id_gen_worker = self._create_id_generator("worker")
|
|
||||||
|
|
||||||
# Restarted worker should not see the second persisted row
|
|
||||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
|
|
||||||
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
|
|
||||||
|
|
||||||
# Now if we persist the first row then both instances should jump ahead
|
|
||||||
# correctly.
|
|
||||||
self.get_success(ctx1.__aexit__(None, None, None))
|
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
|
||||||
id_gen_worker.advance("master", 9)
|
|
||||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
|
||||||
|
|
||||||
def test_writer_config_change(self) -> None:
|
def test_writer_config_change(self) -> None:
|
||||||
"""Test that changing the writer config correctly works."""
|
"""Test that changing the writer config correctly works."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue