mirror of
https://github.com/element-hq/synapse
synced 2024-10-01 12:12:40 +00:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
f5ab7d8306
43 changed files with 1337 additions and 753 deletions
1
changelog.d/9086.feature
Normal file
1
changelog.d/9086.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add an admin API for protecting local media from quarantine.
|
1
changelog.d/9093.misc
Normal file
1
changelog.d/9093.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to media repository.
|
1
changelog.d/9108.bugfix
Normal file
1
changelog.d/9108.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large.
|
1
changelog.d/9110.feature
Normal file
1
changelog.d/9110.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for multiple SSO Identity Providers.
|
1
changelog.d/9117.bugfix
Normal file
1
changelog.d/9117.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix corruption of `pushers` data when a postgres bouncer is used.
|
1
changelog.d/9124.misc
Normal file
1
changelog.d/9124.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve efficiency of large state resolutions.
|
1
changelog.d/9125.misc
Normal file
1
changelog.d/9125.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove dependency on `distutils`.
|
1
changelog.d/9130.feature
Normal file
1
changelog.d/9130.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add experimental support for handling and persistence of to-device messages to happen on worker processes.
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
|||
matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
|
||||
|
||||
* Remove dependency on `python3-distutils`.
|
||||
|
||||
-- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000
|
||||
|
||||
matrix-synapse-py3 (1.25.0) stable; urgency=medium
|
||||
|
||||
[ Dan Callahan ]
|
||||
|
|
1
debian/control
vendored
1
debian/control
vendored
|
@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
|
|||
Depends:
|
||||
adduser,
|
||||
debconf,
|
||||
python3-distutils|libpython3-stdlib (<< 3.6),
|
||||
${misc:Depends},
|
||||
${shlibs:Depends},
|
||||
${synapse:pydepends},
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
* [Quarantining media by ID](#quarantining-media-by-id)
|
||||
* [Quarantining media in a room](#quarantining-media-in-a-room)
|
||||
* [Quarantining all media of a user](#quarantining-all-media-of-a-user)
|
||||
* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
|
||||
- [Delete local media](#delete-local-media)
|
||||
* [Delete a specific local media](#delete-a-specific-local-media)
|
||||
* [Delete local media by date or size](#delete-local-media-by-date-or-size)
|
||||
|
@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:
|
|||
|
||||
* `num_quarantined`: integer - The number of media items successfully quarantined
|
||||
|
||||
## Protecting media from being quarantined
|
||||
|
||||
This API protects a single piece of local media from being quarantined using the
|
||||
above APIs. This is useful for sticker packs and other shared media which you do
|
||||
not want to get quarantined, especially when
|
||||
[quarantining media in a room](#quarantining-media-in-a-room).
|
||||
|
||||
Request:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/media/protect/<media_id>
|
||||
|
||||
{}
|
||||
```
|
||||
|
||||
Where `media_id` is in the form of `abcdefg12345...`.
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
# Delete local media
|
||||
This API deletes the *local* media from the disk of your own server.
|
||||
This includes any local thumbnails and copies of media downloaded from
|
||||
|
|
199
docs/openid.md
199
docs/openid.md
|
@ -42,11 +42,10 @@ as follows:
|
|||
* For other installation mechanisms, see the documentation provided by the
|
||||
maintainer.
|
||||
|
||||
To enable the OpenID integration, you should then add an `oidc_config` section
|
||||
to your configuration file (or uncomment the `enabled: true` line in the
|
||||
existing section). See [sample_config.yaml](./sample_config.yaml) for some
|
||||
sample settings, as well as the text below for example configurations for
|
||||
specific providers.
|
||||
To enable the OpenID integration, you should then add a section to the `oidc_providers`
|
||||
setting in your configuration file (or uncomment one of the existing examples).
|
||||
See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
|
||||
the text below for example configurations for specific providers.
|
||||
|
||||
## Sample configs
|
||||
|
||||
|
@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links.
|
|||
Edit your Synapse config file and change the `oidc_config` section:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
|
||||
client_id: "<client id>"
|
||||
client_secret: "<client secret>"
|
||||
scopes: ["openid", "profile"]
|
||||
authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
|
||||
token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
|
||||
userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
|
||||
oidc_providers:
|
||||
- idp_id: microsoft
|
||||
idp_name: Microsoft
|
||||
issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
|
||||
client_id: "<client id>"
|
||||
client_secret: "<client secret>"
|
||||
scopes: ["openid", "profile"]
|
||||
authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
|
||||
token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
|
||||
userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
|
||||
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username.split('@')[0] }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username.split('@')[0] }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
|
||||
### [Dex][dex-idp]
|
||||
|
@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`.
|
|||
Synapse config:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
skip_verification: true # This is needed as Dex is served on an insecure endpoint
|
||||
issuer: "http://127.0.0.1:5556/dex"
|
||||
client_id: "synapse"
|
||||
client_secret: "secret"
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.name }}"
|
||||
display_name_template: "{{ user.name|capitalize }}"
|
||||
oidc_providers:
|
||||
- idp_id: dex
|
||||
idp_name: "My Dex server"
|
||||
skip_verification: true # This is needed as Dex is served on an insecure endpoint
|
||||
issuer: "http://127.0.0.1:5556/dex"
|
||||
client_id: "synapse"
|
||||
client_secret: "secret"
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.name }}"
|
||||
display_name_template: "{{ user.name|capitalize }}"
|
||||
```
|
||||
### [Keycloak][keycloak-idp]
|
||||
|
||||
|
@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
|
|||
8. Copy Secret
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
|
||||
client_id: "synapse"
|
||||
client_secret: "copy secret generated from above"
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
oidc_providers:
|
||||
- idp_id: keycloak
|
||||
idp_name: "My KeyCloak server"
|
||||
issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
|
||||
client_id: "synapse"
|
||||
client_secret: "copy secret generated from above"
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
### [Auth0][auth0]
|
||||
|
||||
|
@ -191,16 +193,17 @@ oidc_config:
|
|||
Synapse config:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
oidc_providers:
|
||||
- idp_id: auth0
|
||||
idp_name: Auth0
|
||||
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
|
||||
### GitHub
|
||||
|
@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set.
|
|||
Synapse config:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
discover: false
|
||||
issuer: "https://github.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
authorization_endpoint: "https://github.com/login/oauth/authorize"
|
||||
token_endpoint: "https://github.com/login/oauth/access_token"
|
||||
userinfo_endpoint: "https://api.github.com/user"
|
||||
scopes: ["read:user"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
subject_claim: "id"
|
||||
localpart_template: "{{ user.login }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
oidc_providers:
|
||||
- idp_id: github
|
||||
idp_name: Github
|
||||
discover: false
|
||||
issuer: "https://github.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
authorization_endpoint: "https://github.com/login/oauth/authorize"
|
||||
token_endpoint: "https://github.com/login/oauth/access_token"
|
||||
userinfo_endpoint: "https://api.github.com/user"
|
||||
scopes: ["read:user"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
subject_claim: "id"
|
||||
localpart_template: "{{ user.login }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
|
||||
### [Google][google-idp]
|
||||
|
@ -243,16 +247,17 @@ oidc_config:
|
|||
2. add an "OAuth Client ID" for a Web Application under "Credentials".
|
||||
3. Copy the Client ID and Client Secret, and add the following to your synapse config:
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://accounts.google.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.given_name|lower }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
oidc_providers:
|
||||
- idp_id: google
|
||||
idp_name: Google
|
||||
issuer: "https://accounts.google.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
scopes: ["openid", "profile"]
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.given_name|lower }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
4. Back in the Google console, add this Authorized redirect URI: `[synapse
|
||||
public baseurl]/_synapse/oidc/callback`.
|
||||
|
@ -266,16 +271,17 @@ oidc_config:
|
|||
Synapse config:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://id.twitch.tv/oauth2/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
client_auth_method: "client_secret_post"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
oidc_providers:
|
||||
- idp_id: twitch
|
||||
idp_name: Twitch
|
||||
issuer: "https://id.twitch.tv/oauth2/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
client_auth_method: "client_secret_post"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: "{{ user.preferred_username }}"
|
||||
display_name_template: "{{ user.name }}"
|
||||
```
|
||||
|
||||
### GitLab
|
||||
|
@ -287,16 +293,17 @@ oidc_config:
|
|||
Synapse config:
|
||||
|
||||
```yaml
|
||||
oidc_config:
|
||||
enabled: true
|
||||
issuer: "https://gitlab.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
client_auth_method: "client_secret_post"
|
||||
scopes: ["openid", "read_user"]
|
||||
user_profile_method: "userinfo_endpoint"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: '{{ user.nickname }}'
|
||||
display_name_template: '{{ user.name }}'
|
||||
oidc_providers:
|
||||
- idp_id: gitlab
|
||||
idp_name: Gitlab
|
||||
issuer: "https://gitlab.com/"
|
||||
client_id: "your-client-id" # TO BE FILLED
|
||||
client_secret: "your-client-secret" # TO BE FILLED
|
||||
client_auth_method: "client_secret_post"
|
||||
scopes: ["openid", "read_user"]
|
||||
user_profile_method: "userinfo_endpoint"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
localpart_template: '{{ user.nickname }}'
|
||||
display_name_template: '{{ user.name }}'
|
||||
```
|
||||
|
|
|
@ -1709,141 +1709,149 @@ saml2_config:
|
|||
#idp_entityid: 'https://our_idp/entityid'
|
||||
|
||||
|
||||
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
|
||||
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
|
||||
# and login.
|
||||
#
|
||||
# Options for each entry include:
|
||||
#
|
||||
# idp_id: a unique identifier for this identity provider. Used internally
|
||||
# by Synapse; should be a single word such as 'github'.
|
||||
#
|
||||
# Note that, if this is changed, users authenticating via that provider
|
||||
# will no longer be recognised as the same user!
|
||||
#
|
||||
# idp_name: A user-facing name for this identity provider, which is used to
|
||||
# offer the user a choice of login mechanisms.
|
||||
#
|
||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||
# to discover endpoints. Defaults to true.
|
||||
#
|
||||
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
|
||||
# is enabled) to discover the provider's endpoints.
|
||||
#
|
||||
# client_id: Required. oauth2 client id to use.
|
||||
#
|
||||
# client_secret: Required. oauth2 client secret to use.
|
||||
#
|
||||
# client_auth_method: auth method to use when exchanging the token. Valid
|
||||
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||
# 'none'.
|
||||
#
|
||||
# scopes: list of scopes to request. This should normally include the "openid"
|
||||
# scope. Defaults to ["openid"].
|
||||
#
|
||||
# authorization_endpoint: the oauth2 authorization endpoint. Required if
|
||||
# provider discovery is disabled.
|
||||
#
|
||||
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
|
||||
# disabled.
|
||||
#
|
||||
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
|
||||
# disabled and the 'openid' scope is not requested.
|
||||
#
|
||||
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
|
||||
# the 'openid' scope is used.
|
||||
#
|
||||
# skip_verification: set to 'true' to skip metadata verification. Use this if
|
||||
# you are connecting to a provider that is not OpenID Connect compliant.
|
||||
# Defaults to false. Avoid this in production.
|
||||
#
|
||||
# user_profile_method: Whether to fetch the user profile from the userinfo
|
||||
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
|
||||
#
|
||||
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
|
||||
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
|
||||
# userinfo endpoint.
|
||||
#
|
||||
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
|
||||
# match a pre-existing account instead of failing. This could be used if
|
||||
# switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
# user_mapping_provider: Configuration for how attributes returned from a OIDC
|
||||
# provider are mapped onto a matrix user. This setting has the following
|
||||
# sub-properties:
|
||||
#
|
||||
# module: The class name of a custom mapping module. Default is
|
||||
# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
|
||||
# for information on implementing a custom mapping provider.
|
||||
#
|
||||
# config: Configuration for the mapping provider module. This section will
|
||||
# be passed as a Python dictionary to the user mapping provider
|
||||
# module's `parse_config` method.
|
||||
#
|
||||
# For the default provider, the following settings are available:
|
||||
#
|
||||
# sub: name of the claim containing a unique identifier for the
|
||||
# user. Defaults to 'sub', which OpenID Connect compliant
|
||||
# providers should provide.
|
||||
#
|
||||
# localpart_template: Jinja2 template for the localpart of the MXID.
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username.
|
||||
#
|
||||
# display_name_template: Jinja2 template for the display name to set
|
||||
# on first login. If unset, no displayname will be set.
|
||||
#
|
||||
# extra_attributes: a map of Jinja2 templates for extra attributes
|
||||
# to send back to the client during login.
|
||||
# Note that these are non-standard and clients will ignore them
|
||||
# without modifications.
|
||||
#
|
||||
# When rendering, the Jinja2 templates are given a 'user' variable,
|
||||
# which is set to the claims returned by the UserInfo Endpoint and/or
|
||||
# in the ID Token.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
|
||||
# for some example configurations.
|
||||
# for information on how to configure these options.
|
||||
#
|
||||
oidc_config:
|
||||
# Uncomment the following to enable authorization against an OpenID Connect
|
||||
# server. Defaults to false.
|
||||
# For backwards compatibility, it is also possible to configure a single OIDC
|
||||
# provider via an 'oidc_config' setting. This is now deprecated and admins are
|
||||
# advised to migrate to the 'oidc_providers' format.
|
||||
#
|
||||
oidc_providers:
|
||||
# Generic example
|
||||
#
|
||||
#enabled: true
|
||||
#- idp_id: my_idp
|
||||
# idp_name: "My OpenID provider"
|
||||
# discover: false
|
||||
# issuer: "https://accounts.example.com/"
|
||||
# client_id: "provided-by-your-issuer"
|
||||
# client_secret: "provided-by-your-issuer"
|
||||
# client_auth_method: client_secret_post
|
||||
# scopes: ["openid", "profile"]
|
||||
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
|
||||
# token_endpoint: "https://accounts.example.com/oauth2/token"
|
||||
# userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
# skip_verification: true
|
||||
|
||||
# Uncomment the following to disable use of the OIDC discovery mechanism to
|
||||
# discover endpoints. Defaults to true.
|
||||
# For use with Keycloak
|
||||
#
|
||||
#discover: false
|
||||
#- idp_id: keycloak
|
||||
# idp_name: Keycloak
|
||||
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
|
||||
# client_id: "synapse"
|
||||
# client_secret: "copy secret generated in Keycloak UI"
|
||||
# scopes: ["openid", "profile"]
|
||||
|
||||
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
|
||||
# discover the provider's endpoints.
|
||||
# For use with Github
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#issuer: "https://accounts.example.com/"
|
||||
|
||||
# oauth2 client id to use.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#client_id: "provided-by-your-issuer"
|
||||
|
||||
# oauth2 client secret to use.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#client_secret: "provided-by-your-issuer"
|
||||
|
||||
# auth method to use when exchanging the token.
|
||||
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||
# 'none'.
|
||||
#
|
||||
#client_auth_method: client_secret_post
|
||||
|
||||
# list of scopes to request. This should normally include the "openid" scope.
|
||||
# Defaults to ["openid"].
|
||||
#
|
||||
#scopes: ["openid", "profile"]
|
||||
|
||||
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
|
||||
|
||||
# the oauth2 token endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#token_endpoint: "https://accounts.example.com/oauth2/token"
|
||||
|
||||
# the OIDC userinfo endpoint. Required if discovery is disabled and the
|
||||
# "openid" scope is not requested.
|
||||
#
|
||||
#userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
|
||||
# URI where to fetch the JWKS. Required if discovery is disabled and the
|
||||
# "openid" scope is used.
|
||||
#
|
||||
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
|
||||
# Uncomment to skip metadata verification. Defaults to false.
|
||||
#
|
||||
# Use this if you are connecting to a provider that is not OpenID Connect
|
||||
# compliant.
|
||||
# Avoid this in production.
|
||||
#
|
||||
#skip_verification: true
|
||||
|
||||
# Whether to fetch the user profile from the userinfo endpoint. Valid
|
||||
# values are: "auto" or "userinfo_endpoint".
|
||||
#
|
||||
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
|
||||
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
|
||||
#
|
||||
#user_profile_method: "userinfo_endpoint"
|
||||
|
||||
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
|
||||
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
#allow_existing_users: true
|
||||
|
||||
# An external module can be provided here as a custom solution to mapping
|
||||
# attributes returned from a OIDC provider onto a matrix user.
|
||||
#
|
||||
user_mapping_provider:
|
||||
# The custom module's class. Uncomment to use a custom module.
|
||||
# Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
|
||||
# for information on implementing a custom mapping provider.
|
||||
#
|
||||
#module: mapping_provider.OidcMappingProvider
|
||||
|
||||
# Custom configuration values for the module. This section will be passed as
|
||||
# a Python dictionary to the user mapping provider module's `parse_config`
|
||||
# method.
|
||||
#
|
||||
# The examples below are intended for the default provider: they should be
|
||||
# changed if using a custom provider.
|
||||
#
|
||||
config:
|
||||
# name of the claim containing a unique identifier for the user.
|
||||
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
|
||||
#
|
||||
#subject_claim: "sub"
|
||||
|
||||
# Jinja2 template for the localpart of the MXID.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
|
||||
# Token
|
||||
#
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username.
|
||||
#
|
||||
#localpart_template: "{{ user.preferred_username }}"
|
||||
|
||||
# Jinja2 template for the display name to set on first login.
|
||||
#
|
||||
# If unset, no displayname will be set.
|
||||
#
|
||||
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
|
||||
|
||||
# Jinja2 templates for extra attributes to send back to the client during
|
||||
# login.
|
||||
#
|
||||
# Note that these are non-standard and clients will ignore them without modifications.
|
||||
#
|
||||
#extra_attributes:
|
||||
#birthdate: "{{ user.birthdate }}"
|
||||
|
||||
#- idp_id: google
|
||||
# idp_name: Google
|
||||
# discover: false
|
||||
# issuer: "https://github.com/"
|
||||
# client_id: "your-client-id" # TO BE FILLED
|
||||
# client_secret: "your-client-secret" # TO BE FILLED
|
||||
# authorization_endpoint: "https://github.com/login/oauth/authorize"
|
||||
# token_endpoint: "https://github.com/login/oauth/access_token"
|
||||
# userinfo_endpoint: "https://api.github.com/user"
|
||||
# scopes: ["read:user"]
|
||||
# user_mapping_provider:
|
||||
# config:
|
||||
# subject_claim: "id"
|
||||
# localpart_template: "{ user.login }"
|
||||
# display_name_template: "{ user.name }"
|
||||
|
||||
|
||||
# Enable Central Authentication Service (CAS) for registration and login.
|
||||
|
|
|
@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
|
|||
be used for demo purposes and any admin considering workers should already be
|
||||
running PostgreSQL.
|
||||
|
||||
See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
|
||||
for a higher level overview.
|
||||
|
||||
## Main process/worker communication
|
||||
|
||||
The processes communicate with each other via a Synapse-specific protocol called
|
||||
|
|
|
@ -40,7 +40,7 @@ class CasConfig(Config):
|
|||
self.cas_required_attributes = {}
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
return """\
|
||||
# Enable Central Authentication Service (CAS) for registration and login.
|
||||
#
|
||||
cas_config:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import string
|
||||
from typing import Optional, Type
|
||||
from typing import Iterable, Optional, Type
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -33,16 +33,8 @@ class OIDCConfig(Config):
|
|||
section = "oidc"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
validate_config(MAIN_CONFIG_SCHEMA, config, ())
|
||||
|
||||
self.oidc_provider = None # type: Optional[OidcProviderConfig]
|
||||
|
||||
oidc_config = config.get("oidc_config")
|
||||
if oidc_config and oidc_config.get("enabled", False):
|
||||
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
|
||||
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
|
||||
|
||||
if not self.oidc_provider:
|
||||
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
|
||||
if not self.oidc_providers:
|
||||
return
|
||||
|
||||
try:
|
||||
|
@ -58,144 +50,153 @@ class OIDCConfig(Config):
|
|||
@property
|
||||
def oidc_enabled(self) -> bool:
|
||||
# OIDC is enabled if we have a provider
|
||||
return bool(self.oidc_provider)
|
||||
return bool(self.oidc_providers)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
|
||||
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
|
||||
# and login.
|
||||
#
|
||||
# Options for each entry include:
|
||||
#
|
||||
# idp_id: a unique identifier for this identity provider. Used internally
|
||||
# by Synapse; should be a single word such as 'github'.
|
||||
#
|
||||
# Note that, if this is changed, users authenticating via that provider
|
||||
# will no longer be recognised as the same user!
|
||||
#
|
||||
# idp_name: A user-facing name for this identity provider, which is used to
|
||||
# offer the user a choice of login mechanisms.
|
||||
#
|
||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||
# to discover endpoints. Defaults to true.
|
||||
#
|
||||
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
|
||||
# is enabled) to discover the provider's endpoints.
|
||||
#
|
||||
# client_id: Required. oauth2 client id to use.
|
||||
#
|
||||
# client_secret: Required. oauth2 client secret to use.
|
||||
#
|
||||
# client_auth_method: auth method to use when exchanging the token. Valid
|
||||
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||
# 'none'.
|
||||
#
|
||||
# scopes: list of scopes to request. This should normally include the "openid"
|
||||
# scope. Defaults to ["openid"].
|
||||
#
|
||||
# authorization_endpoint: the oauth2 authorization endpoint. Required if
|
||||
# provider discovery is disabled.
|
||||
#
|
||||
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
|
||||
# disabled.
|
||||
#
|
||||
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
|
||||
# disabled and the 'openid' scope is not requested.
|
||||
#
|
||||
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
|
||||
# the 'openid' scope is used.
|
||||
#
|
||||
# skip_verification: set to 'true' to skip metadata verification. Use this if
|
||||
# you are connecting to a provider that is not OpenID Connect compliant.
|
||||
# Defaults to false. Avoid this in production.
|
||||
#
|
||||
# user_profile_method: Whether to fetch the user profile from the userinfo
|
||||
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
|
||||
#
|
||||
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
|
||||
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
|
||||
# userinfo endpoint.
|
||||
#
|
||||
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
|
||||
# match a pre-existing account instead of failing. This could be used if
|
||||
# switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
# user_mapping_provider: Configuration for how attributes returned from a OIDC
|
||||
# provider are mapped onto a matrix user. This setting has the following
|
||||
# sub-properties:
|
||||
#
|
||||
# module: The class name of a custom mapping module. Default is
|
||||
# {mapping_provider!r}.
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
|
||||
# for information on implementing a custom mapping provider.
|
||||
#
|
||||
# config: Configuration for the mapping provider module. This section will
|
||||
# be passed as a Python dictionary to the user mapping provider
|
||||
# module's `parse_config` method.
|
||||
#
|
||||
# For the default provider, the following settings are available:
|
||||
#
|
||||
# sub: name of the claim containing a unique identifier for the
|
||||
# user. Defaults to 'sub', which OpenID Connect compliant
|
||||
# providers should provide.
|
||||
#
|
||||
# localpart_template: Jinja2 template for the localpart of the MXID.
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username.
|
||||
#
|
||||
# display_name_template: Jinja2 template for the display name to set
|
||||
# on first login. If unset, no displayname will be set.
|
||||
#
|
||||
# extra_attributes: a map of Jinja2 templates for extra attributes
|
||||
# to send back to the client during login.
|
||||
# Note that these are non-standard and clients will ignore them
|
||||
# without modifications.
|
||||
#
|
||||
# When rendering, the Jinja2 templates are given a 'user' variable,
|
||||
# which is set to the claims returned by the UserInfo Endpoint and/or
|
||||
# in the ID Token.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
|
||||
# for some example configurations.
|
||||
# for information on how to configure these options.
|
||||
#
|
||||
oidc_config:
|
||||
# Uncomment the following to enable authorization against an OpenID Connect
|
||||
# server. Defaults to false.
|
||||
# For backwards compatibility, it is also possible to configure a single OIDC
|
||||
# provider via an 'oidc_config' setting. This is now deprecated and admins are
|
||||
# advised to migrate to the 'oidc_providers' format.
|
||||
#
|
||||
oidc_providers:
|
||||
# Generic example
|
||||
#
|
||||
#enabled: true
|
||||
#- idp_id: my_idp
|
||||
# idp_name: "My OpenID provider"
|
||||
# discover: false
|
||||
# issuer: "https://accounts.example.com/"
|
||||
# client_id: "provided-by-your-issuer"
|
||||
# client_secret: "provided-by-your-issuer"
|
||||
# client_auth_method: client_secret_post
|
||||
# scopes: ["openid", "profile"]
|
||||
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
|
||||
# token_endpoint: "https://accounts.example.com/oauth2/token"
|
||||
# userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
# skip_verification: true
|
||||
|
||||
# Uncomment the following to disable use of the OIDC discovery mechanism to
|
||||
# discover endpoints. Defaults to true.
|
||||
# For use with Keycloak
|
||||
#
|
||||
#discover: false
|
||||
#- idp_id: keycloak
|
||||
# idp_name: Keycloak
|
||||
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
|
||||
# client_id: "synapse"
|
||||
# client_secret: "copy secret generated in Keycloak UI"
|
||||
# scopes: ["openid", "profile"]
|
||||
|
||||
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
|
||||
# discover the provider's endpoints.
|
||||
# For use with Github
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#issuer: "https://accounts.example.com/"
|
||||
|
||||
# oauth2 client id to use.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#client_id: "provided-by-your-issuer"
|
||||
|
||||
# oauth2 client secret to use.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#client_secret: "provided-by-your-issuer"
|
||||
|
||||
# auth method to use when exchanging the token.
|
||||
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||
# 'none'.
|
||||
#
|
||||
#client_auth_method: client_secret_post
|
||||
|
||||
# list of scopes to request. This should normally include the "openid" scope.
|
||||
# Defaults to ["openid"].
|
||||
#
|
||||
#scopes: ["openid", "profile"]
|
||||
|
||||
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
|
||||
|
||||
# the oauth2 token endpoint. Required if provider discovery is disabled.
|
||||
#
|
||||
#token_endpoint: "https://accounts.example.com/oauth2/token"
|
||||
|
||||
# the OIDC userinfo endpoint. Required if discovery is disabled and the
|
||||
# "openid" scope is not requested.
|
||||
#
|
||||
#userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
|
||||
# URI where to fetch the JWKS. Required if discovery is disabled and the
|
||||
# "openid" scope is used.
|
||||
#
|
||||
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
|
||||
# Uncomment to skip metadata verification. Defaults to false.
|
||||
#
|
||||
# Use this if you are connecting to a provider that is not OpenID Connect
|
||||
# compliant.
|
||||
# Avoid this in production.
|
||||
#
|
||||
#skip_verification: true
|
||||
|
||||
# Whether to fetch the user profile from the userinfo endpoint. Valid
|
||||
# values are: "auto" or "userinfo_endpoint".
|
||||
#
|
||||
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
|
||||
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
|
||||
#
|
||||
#user_profile_method: "userinfo_endpoint"
|
||||
|
||||
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
|
||||
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
#allow_existing_users: true
|
||||
|
||||
# An external module can be provided here as a custom solution to mapping
|
||||
# attributes returned from a OIDC provider onto a matrix user.
|
||||
#
|
||||
user_mapping_provider:
|
||||
# The custom module's class. Uncomment to use a custom module.
|
||||
# Default is {mapping_provider!r}.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
|
||||
# for information on implementing a custom mapping provider.
|
||||
#
|
||||
#module: mapping_provider.OidcMappingProvider
|
||||
|
||||
# Custom configuration values for the module. This section will be passed as
|
||||
# a Python dictionary to the user mapping provider module's `parse_config`
|
||||
# method.
|
||||
#
|
||||
# The examples below are intended for the default provider: they should be
|
||||
# changed if using a custom provider.
|
||||
#
|
||||
config:
|
||||
# name of the claim containing a unique identifier for the user.
|
||||
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
|
||||
#
|
||||
#subject_claim: "sub"
|
||||
|
||||
# Jinja2 template for the localpart of the MXID.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
|
||||
# Token
|
||||
#
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username.
|
||||
#
|
||||
#localpart_template: "{{{{ user.preferred_username }}}}"
|
||||
|
||||
# Jinja2 template for the display name to set on first login.
|
||||
#
|
||||
# If unset, no displayname will be set.
|
||||
#
|
||||
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
|
||||
|
||||
# Jinja2 templates for extra attributes to send back to the client during
|
||||
# login.
|
||||
#
|
||||
# Note that these are non-standard and clients will ignore them without modifications.
|
||||
#
|
||||
#extra_attributes:
|
||||
#birthdate: "{{{{ user.birthdate }}}}"
|
||||
#- idp_id: google
|
||||
# idp_name: Google
|
||||
# discover: false
|
||||
# issuer: "https://github.com/"
|
||||
# client_id: "your-client-id" # TO BE FILLED
|
||||
# client_secret: "your-client-secret" # TO BE FILLED
|
||||
# authorization_endpoint: "https://github.com/login/oauth/authorize"
|
||||
# token_endpoint: "https://github.com/login/oauth/access_token"
|
||||
# userinfo_endpoint: "https://api.github.com/user"
|
||||
# scopes: ["read:user"]
|
||||
# user_mapping_provider:
|
||||
# config:
|
||||
# subject_claim: "id"
|
||||
# localpart_template: "{{ user.login }}"
|
||||
# display_name_template: "{{ user.name }}"
|
||||
""".format(
|
||||
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
|
||||
)
|
||||
|
@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
|||
},
|
||||
}
|
||||
|
||||
# the `oidc_config` setting can either be None (as it is in the default
|
||||
# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
|
||||
OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
|
||||
"allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
|
||||
}
|
||||
|
||||
|
||||
# the `oidc_providers` list can either be None (as it is in the default config), or
|
||||
# a list of provider configs, each of which requires an explicit ID and name.
|
||||
OIDC_PROVIDER_LIST_SCHEMA = {
|
||||
"oneOf": [
|
||||
{"type": "null"},
|
||||
{"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
|
||||
]
|
||||
}
|
||||
|
||||
# the `oidc_config` setting can either be None (which it used to be in the default
|
||||
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
|
||||
# property.
|
||||
#
|
||||
|
@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
|||
# additional checks in the code.
|
||||
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
|
||||
|
||||
# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
|
||||
MAIN_CONFIG_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
|
||||
"properties": {
|
||||
"oidc_config": OIDC_CONFIG_SCHEMA,
|
||||
"oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
|
||||
"""extract and parse the OIDC provider configs from the config dict
|
||||
|
||||
The configuration may contain either a single `oidc_config` object with an
|
||||
`enabled: True` property, or a list of provider configurations under
|
||||
`oidc_providers`, *or both*.
|
||||
|
||||
Returns a generator which yields the OidcProviderConfig objects
|
||||
"""
|
||||
validate_config(MAIN_CONFIG_SCHEMA, config, ())
|
||||
|
||||
for p in config.get("oidc_providers") or []:
|
||||
yield _parse_oidc_config_dict(p)
|
||||
|
||||
# for backwards-compatibility, it is also possible to provide a single "oidc_config"
|
||||
# object with an "enabled: True" property.
|
||||
oidc_config = config.get("oidc_config")
|
||||
if oidc_config and oidc_config.get("enabled", False):
|
||||
# MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
|
||||
# it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
|
||||
# above), so now we need to validate it.
|
||||
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
|
||||
yield _parse_oidc_config_dict(oidc_config)
|
||||
|
||||
|
||||
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
|
||||
"""Take the configuration dict and parse it into an OidcProviderConfig
|
||||
|
||||
|
|
|
@ -14,14 +14,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from synapse.api.constants import RoomCreationPreset
|
||||
from synapse.config._base import Config, ConfigError
|
||||
from synapse.types import RoomAlias, UserID
|
||||
from synapse.util.stringutils import random_string_with_symbols
|
||||
from synapse.util.stringutils import random_string_with_symbols, strtobool
|
||||
|
||||
|
||||
class AccountValidityConfig(Config):
|
||||
|
@ -86,12 +85,12 @@ class RegistrationConfig(Config):
|
|||
section = "registration"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
self.enable_registration = bool(
|
||||
strtobool(str(config.get("enable_registration", False)))
|
||||
self.enable_registration = strtobool(
|
||||
str(config.get("enable_registration", False))
|
||||
)
|
||||
if "disable_registration" in config:
|
||||
self.enable_registration = not bool(
|
||||
strtobool(str(config["disable_registration"]))
|
||||
self.enable_registration = not strtobool(
|
||||
str(config["disable_registration"])
|
||||
)
|
||||
|
||||
self.account_validity = AccountValidityConfig(
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
import abc
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
|
|||
from synapse.types import JsonDict, RoomStreamToken
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.util.frozenutils import freeze
|
||||
from synapse.util.stringutils import strtobool
|
||||
|
||||
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
|
||||
# bugs where we accidentally share e.g. signature dicts. However, converting a
|
||||
|
@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
|
|||
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
|
||||
# for the sake of tests, it is set here while it cannot be configured on the
|
||||
# homeserver object itself.
|
||||
|
||||
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
|
||||
|
||||
|
||||
|
|
|
@ -163,7 +163,7 @@ class DeviceMessageHandler:
|
|||
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
|
||||
|
||||
# Immediately attempt a resync in the background
|
||||
run_in_background(self._user_device_resync, sender_user_id)
|
||||
run_in_background(self._user_device_resync, user_id=sender_user_id)
|
||||
|
||||
async def send_device_message(
|
||||
self,
|
||||
|
|
|
@ -78,21 +78,28 @@ class OidcHandler:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
provider_conf = hs.config.oidc.oidc_provider
|
||||
provider_confs = hs.config.oidc.oidc_providers
|
||||
# we should not have been instantiated if there is no configured provider.
|
||||
assert provider_conf is not None
|
||||
assert provider_confs
|
||||
|
||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
|
||||
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
|
||||
self._providers = {
|
||||
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
|
||||
}
|
||||
|
||||
async def load_metadata(self) -> None:
|
||||
"""Validate the config and load the metadata from the remote endpoint.
|
||||
|
||||
Called at startup to ensure we have everything we need.
|
||||
"""
|
||||
await self._provider.load_metadata()
|
||||
await self._provider.load_jwks()
|
||||
for idp_id, p in self._providers.items():
|
||||
try:
|
||||
await p.load_metadata()
|
||||
await p.load_jwks()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Error while initialising OIDC provider %r" % (idp_id,)
|
||||
) from e
|
||||
|
||||
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
|
||||
"""Handle an incoming request to /_synapse/oidc/callback
|
||||
|
@ -184,6 +191,12 @@ class OidcHandler:
|
|||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||
return
|
||||
|
||||
oidc_provider = self._providers.get(session_data.idp_id)
|
||||
if not oidc_provider:
|
||||
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
|
||||
self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
|
||||
return
|
||||
|
||||
if b"code" not in request.args:
|
||||
logger.info("Code parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
|
@ -193,7 +206,7 @@ class OidcHandler:
|
|||
|
||||
code = request.args[b"code"][0].decode()
|
||||
|
||||
await self._provider.handle_oidc_callback(request, session_data, code)
|
||||
await oidc_provider.handle_oidc_callback(request, session_data, code)
|
||||
|
||||
|
||||
class OidcError(Exception):
|
||||
|
|
|
@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||
self.max_size = max_size
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
# If the deferred was called, bail early.
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
self.stream.write(data)
|
||||
self.length += len(data)
|
||||
# The first time the maximum size is exceeded, error and cancel the
|
||||
# connection. dataReceived might be called again if data was received
|
||||
# in the meantime.
|
||||
if self.max_size is not None and self.length >= self.max_size:
|
||||
self.deferred.errback(BodyExceededMaxSize())
|
||||
self.deferred = defer.Deferred()
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason: Failure) -> None:
|
||||
# If the maximum size was already exceeded, there's nothing to do.
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
if reason.check(ResponseDone):
|
||||
self.deferred.callback(self.length)
|
||||
elif reason.check(PotentialDataLoss):
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||
|
@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
|
|||
assert_requester_is_admin,
|
||||
assert_user_is_admin,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
|
|||
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, room_id: str):
|
||||
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
|
|||
|
||||
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, user_id: str):
|
||||
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
|
|||
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, server_name: str, media_id: str):
|
||||
async def on_POST(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
class ProtectMediaByID(RestServlet):
|
||||
"""Protect local media from being quarantined.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
logging.info("Protecting local media by ID: %s", media_id)
|
||||
|
||||
# Quarantine this media id
|
||||
await self.store.mark_local_media_as_safe(media_id)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
class ListMediaInRoom(RestServlet):
|
||||
"""Lists all of the media in a given room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
|
@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
|
|||
class PurgeMediaCacheRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/purge_media_cache")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
|
@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
|
|||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_DELETE(self, request, server_name: str, media_id: str):
|
||||
async def on_DELETE(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if self.server_name != server_name:
|
||||
|
@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
|
|||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_POST(self, request, server_name: str):
|
||||
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
|
@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
|
|||
return 200, {"deleted_media": deleted_media, "total": total}
|
||||
|
||||
|
||||
def register_servlets_for_media_repo(hs, http_server):
|
||||
def register_servlets_for_media_repo(hs: "HomeServer", http_server):
|
||||
"""
|
||||
Media repo specific APIs.
|
||||
"""
|
||||
|
@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
|
|||
QuarantineMediaInRoom(hs).register(http_server)
|
||||
QuarantineMediaByID(hs).register(http_server)
|
||||
QuarantineMediaByUser(hs).register(http_server)
|
||||
ProtectMediaByID(hs).register(http_server)
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
DeleteMediaByID(hs).register(http_server)
|
||||
DeleteMediaByDateSize(hs).register(http_server)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -17,10 +17,11 @@
|
|||
import logging
|
||||
import os
|
||||
import urllib
|
||||
from typing import Awaitable
|
||||
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.interfaces import IConsumer
|
||||
from twisted.protocols.basic import FileSender
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
from synapse.http.server import finish_request, respond_with_json
|
||||
|
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
|
|||
]
|
||||
|
||||
|
||||
def parse_media_id(request):
|
||||
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
||||
try:
|
||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||
# clients that parse the URL to see content type.
|
||||
|
@ -69,7 +70,7 @@ def parse_media_id(request):
|
|||
)
|
||||
|
||||
|
||||
def respond_404(request):
|
||||
def respond_404(request: Request) -> None:
|
||||
respond_with_json(
|
||||
request,
|
||||
404,
|
||||
|
@ -79,8 +80,12 @@ def respond_404(request):
|
|||
|
||||
|
||||
async def respond_with_file(
|
||||
request, media_type, file_path, file_size=None, upload_name=None
|
||||
):
|
||||
request: Request,
|
||||
media_type: str,
|
||||
file_path: str,
|
||||
file_size: Optional[int] = None,
|
||||
upload_name: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
|
@ -98,15 +103,20 @@ async def respond_with_file(
|
|||
respond_404(request)
|
||||
|
||||
|
||||
def add_file_headers(request, media_type, file_size, upload_name):
|
||||
def add_file_headers(
|
||||
request: Request,
|
||||
media_type: str,
|
||||
file_size: Optional[int],
|
||||
upload_name: Optional[str],
|
||||
) -> None:
|
||||
"""Adds the correct response headers in preparation for responding with the
|
||||
media.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request)
|
||||
media_type (str): The media/content type.
|
||||
file_size (int): Size in bytes of the media, if known.
|
||||
upload_name (str): The name of the requested file, if any.
|
||||
request
|
||||
media_type: The media/content type.
|
||||
file_size: Size in bytes of the media, if known.
|
||||
upload_name: The name of the requested file, if any.
|
||||
"""
|
||||
|
||||
def _quote(x):
|
||||
|
@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
|||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
||||
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
||||
if file_size is not None:
|
||||
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
||||
|
||||
# Tell web crawlers to not index, archive, or follow links in media. This
|
||||
# should help to prevent things in the media repo from showing up in web
|
||||
|
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
|
|||
}
|
||||
|
||||
|
||||
def _can_encode_filename_as_token(x):
|
||||
def _can_encode_filename_as_token(x: str) -> bool:
|
||||
for c in x:
|
||||
# from RFC2616:
|
||||
#
|
||||
|
@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
|
|||
|
||||
|
||||
async def respond_with_responder(
|
||||
request, responder, media_type, file_size, upload_name=None
|
||||
):
|
||||
request: Request,
|
||||
responder: "Optional[Responder]",
|
||||
media_type: str,
|
||||
file_size: Optional[int],
|
||||
upload_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Responds to the request with given responder. If responder is None then
|
||||
returns 404.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request)
|
||||
responder (Responder|None)
|
||||
media_type (str): The media/content type.
|
||||
file_size (int|None): Size in bytes of the media. If not known it should be None
|
||||
upload_name (str|None): The name of the requested file, if any.
|
||||
request
|
||||
responder
|
||||
media_type: The media/content type.
|
||||
file_size: Size in bytes of the media. If not known it should be None
|
||||
upload_name: The name of the requested file, if any.
|
||||
"""
|
||||
if request._disconnected:
|
||||
logger.warning(
|
||||
|
@ -308,22 +323,22 @@ class FileInfo:
|
|||
self.thumbnail_type = thumbnail_type
|
||||
|
||||
|
||||
def get_filename_from_headers(headers):
|
||||
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
|
||||
"""
|
||||
Get the filename of the downloaded file by inspecting the
|
||||
Content-Disposition HTTP header.
|
||||
|
||||
Args:
|
||||
headers (dict[bytes, list[bytes]]): The HTTP request headers.
|
||||
headers: The HTTP request headers.
|
||||
|
||||
Returns:
|
||||
A Unicode string of the filename, or None.
|
||||
The filename, or None.
|
||||
"""
|
||||
content_disposition = headers.get(b"Content-Disposition", [b""])
|
||||
|
||||
# No header, bail out.
|
||||
if not content_disposition[0]:
|
||||
return
|
||||
return None
|
||||
|
||||
_, params = _parse_header(content_disposition[0])
|
||||
|
||||
|
@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
|
|||
return upload_name
|
||||
|
||||
|
||||
def _parse_header(line):
|
||||
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
|
||||
"""Parse a Content-type like header.
|
||||
|
||||
Cargo-culted from `cgi`, but works on bytes rather than strings.
|
||||
|
||||
Args:
|
||||
line (bytes): header to be parsed
|
||||
line: header to be parsed
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, dict[bytes, bytes]]:
|
||||
the main content-type, followed by the parameter dictionary
|
||||
The main content-type, followed by the parameter dictionary
|
||||
"""
|
||||
parts = _parseparam(b";" + line)
|
||||
key = next(parts)
|
||||
|
@ -386,16 +400,16 @@ def _parse_header(line):
|
|||
return key, pdict
|
||||
|
||||
|
||||
def _parseparam(s):
|
||||
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
|
||||
"""Generator which splits the input on ;, respecting double-quoted sequences
|
||||
|
||||
Cargo-culted from `cgi`, but works on bytes rather than strings.
|
||||
|
||||
Args:
|
||||
s (bytes): header to be parsed
|
||||
s: header to be parsed
|
||||
|
||||
Returns:
|
||||
Iterable[bytes]: the split input
|
||||
The split input
|
||||
"""
|
||||
while s[:1] == b";":
|
||||
s = s[1:]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 Will Hunt <will@half-shot.uk>
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,22 +15,29 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
|
||||
class MediaConfigResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
config = hs.get_config()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
async def _async_render_GET(self, request):
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,24 +14,31 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
import synapse.http.servlet
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_boolean
|
||||
|
||||
from ._base import parse_media_id, respond_404
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def _async_render_GET(self, request):
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
set_cors_headers(request)
|
||||
request.setHeader(
|
||||
b"Content-Security-Policy",
|
||||
|
@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
|
|||
if server_name == self.server_name:
|
||||
await self.media_repo.get_local_media(request, media_id, name)
|
||||
else:
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True
|
||||
)
|
||||
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,11 +17,12 @@
|
|||
import functools
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, List
|
||||
|
||||
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
|
||||
|
||||
|
||||
def _wrap_in_base_path(func):
|
||||
def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
|
||||
"""Takes a function that returns a relative path and turns it into an
|
||||
absolute path based on the location of the primary media store
|
||||
"""
|
||||
|
@ -41,12 +43,18 @@ class MediaFilePaths:
|
|||
to write to the backup media store (when one is configured)
|
||||
"""
|
||||
|
||||
def __init__(self, primary_base_path):
|
||||
def __init__(self, primary_base_path: str):
|
||||
self.base_path = primary_base_path
|
||||
|
||||
def default_thumbnail_rel(
|
||||
self, default_top_level, default_sub_type, width, height, content_type, method
|
||||
):
|
||||
self,
|
||||
default_top_level: str,
|
||||
default_sub_type: str,
|
||||
width: int,
|
||||
height: int,
|
||||
content_type: str,
|
||||
method: str,
|
||||
) -> str:
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
return os.path.join(
|
||||
|
@ -55,12 +63,14 @@ class MediaFilePaths:
|
|||
|
||||
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
|
||||
|
||||
def local_media_filepath_rel(self, media_id):
|
||||
def local_media_filepath_rel(self, media_id: str) -> str:
|
||||
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
|
||||
|
||||
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
|
||||
|
||||
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
|
||||
def local_media_thumbnail_rel(
|
||||
self, media_id: str, width: int, height: int, content_type: str, method: str
|
||||
) -> str:
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
return os.path.join(
|
||||
|
@ -86,7 +96,7 @@ class MediaFilePaths:
|
|||
media_id[4:],
|
||||
)
|
||||
|
||||
def remote_media_filepath_rel(self, server_name, file_id):
|
||||
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
|
||||
return os.path.join(
|
||||
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
|
||||
)
|
||||
|
@ -94,8 +104,14 @@ class MediaFilePaths:
|
|||
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
|
||||
|
||||
def remote_media_thumbnail_rel(
|
||||
self, server_name, file_id, width, height, content_type, method
|
||||
):
|
||||
self,
|
||||
server_name: str,
|
||||
file_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
content_type: str,
|
||||
method: str,
|
||||
) -> str:
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
return os.path.join(
|
||||
|
@ -113,7 +129,7 @@ class MediaFilePaths:
|
|||
# Should be removed after some time, when most of the thumbnails are stored
|
||||
# using the new path.
|
||||
def remote_media_thumbnail_rel_legacy(
|
||||
self, server_name, file_id, width, height, content_type
|
||||
self, server_name: str, file_id: str, width: int, height: int, content_type: str
|
||||
):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||
|
@ -126,7 +142,7 @@ class MediaFilePaths:
|
|||
file_name,
|
||||
)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
|
||||
return os.path.join(
|
||||
self.base_path,
|
||||
"remote_thumbnail",
|
||||
|
@ -136,7 +152,7 @@ class MediaFilePaths:
|
|||
file_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_filepath_rel(self, media_id):
|
||||
def url_cache_filepath_rel(self, media_id: str) -> str:
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
@ -146,7 +162,7 @@ class MediaFilePaths:
|
|||
|
||||
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
|
||||
|
||||
def url_cache_filepath_dirs_to_delete(self, media_id):
|
||||
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
|
||||
"The dirs to try and remove if we delete the media_id file"
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
|
||||
|
@ -156,7 +172,9 @@ class MediaFilePaths:
|
|||
os.path.join(self.base_path, "url_cache", media_id[0:2]),
|
||||
]
|
||||
|
||||
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
|
||||
def url_cache_thumbnail_rel(
|
||||
self, media_id: str, width: int, height: int, content_type: str, method: str
|
||||
) -> str:
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
|
@ -178,7 +196,7 @@ class MediaFilePaths:
|
|||
|
||||
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
|
||||
|
||||
def url_cache_thumbnail_directory(self, media_id):
|
||||
def url_cache_thumbnail_directory(self, media_id: str) -> str:
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
|
@ -195,7 +213,7 @@ class MediaFilePaths:
|
|||
media_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_thumbnail_dirs_to_delete(self, media_id):
|
||||
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
|
||||
"The dirs to try and remove if we delete the media_id thumbnails"
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,12 +13,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import IO, Dict, List, Optional, Tuple
|
||||
from io import BytesIO
|
||||
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import twisted.internet.error
|
||||
import twisted.web.http
|
||||
|
@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
|
|||
from .thumbnailer import Thumbnailer, ThumbnailError
|
||||
from .upload_resource import UploadResource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
|||
|
||||
|
||||
class MediaRepository:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.client = hs.get_federation_http_client()
|
||||
|
@ -73,16 +76,16 @@ class MediaRepository:
|
|||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
|
||||
self.primary_base_path = hs.config.media_store_path
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
self.primary_base_path = hs.config.media_store_path # type: str
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
self.recently_accessed_locals = set()
|
||||
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
|
||||
self.recently_accessed_locals = set() # type: Set[str]
|
||||
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
|
@ -113,7 +116,7 @@ class MediaRepository:
|
|||
"update_recently_accessed_media", self._update_recently_accessed
|
||||
)
|
||||
|
||||
async def _update_recently_accessed(self):
|
||||
async def _update_recently_accessed(self) -> None:
|
||||
remote_media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
|
@ -124,12 +127,12 @@ class MediaRepository:
|
|||
local_media, remote_media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
def mark_recently_accessed(self, server_name, media_id):
|
||||
def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
|
||||
"""Mark the given media as recently accessed.
|
||||
|
||||
Args:
|
||||
server_name (str|None): Origin server of media, or None if local
|
||||
media_id (str): The media ID of the content
|
||||
server_name: Origin server of media, or None if local
|
||||
media_id: The media ID of the content
|
||||
"""
|
||||
if server_name:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
|
@ -459,7 +462,14 @@ class MediaRepository:
|
|||
def _get_thumbnail_requirements(self, media_type):
|
||||
return self.thumbnail_requirements.get(media_type, ())
|
||||
|
||||
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
|
||||
def _generate_thumbnail(
|
||||
self,
|
||||
thumbnailer: Thumbnailer,
|
||||
t_width: int,
|
||||
t_height: int,
|
||||
t_method: str,
|
||||
t_type: str,
|
||||
) -> Optional[BytesIO]:
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
|
@ -470,22 +480,20 @@ class MediaRepository:
|
|||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = thumbnailer.transpose()
|
||||
|
||||
if t_method == "crop":
|
||||
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
|
||||
return thumbnailer.crop(t_width, t_height, t_type)
|
||||
elif t_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(t_width, t_height)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
|
||||
else:
|
||||
t_byte_source = None
|
||||
return thumbnailer.scale(t_width, t_height, t_type)
|
||||
|
||||
return t_byte_source
|
||||
return None
|
||||
|
||||
async def generate_local_exact_thumbnail(
|
||||
self,
|
||||
|
@ -776,7 +784,7 @@ class MediaRepository:
|
|||
|
||||
return {"width": m_width, "height": m_height}
|
||||
|
||||
async def delete_old_remote_media(self, before_ts):
|
||||
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
|
||||
old_media = await self.store.get_remote_media_before(before_ts)
|
||||
|
||||
deleted = 0
|
||||
|
@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
|
|||
within a given rectangle.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
# If we're not configured to use it, raise if we somehow got here.
|
||||
if not hs.config.can_load_media_repo:
|
||||
raise ConfigError("Synapse is not configured to use a media repo.")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vecotr Ltd
|
||||
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,8 @@ import os
|
|||
import shutil
|
||||
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IConsumer
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
||||
|
@ -270,7 +272,7 @@ class MediaStorage:
|
|||
return self.filepaths.local_media_filepath_rel(file_info.file_id)
|
||||
|
||||
|
||||
def _write_file_synchronously(source, dest):
|
||||
def _write_file_synchronously(source: IO, dest: IO) -> None:
|
||||
"""Write `source` to the file like `dest` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
|
@ -286,14 +288,14 @@ class FileResponder(Responder):
|
|||
"""Wraps an open file that can be sent to a request.
|
||||
|
||||
Args:
|
||||
open_file (file): A file like object to be streamed ot the client,
|
||||
open_file: A file like object to be streamed ot the client,
|
||||
is closed when finished streaming.
|
||||
"""
|
||||
|
||||
def __init__(self, open_file):
|
||||
def __init__(self, open_file: IO):
|
||||
self.open_file = open_file
|
||||
|
||||
def write_to_consumer(self, consumer):
|
||||
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
|
||||
return make_deferred_yieldable(
|
||||
FileSender().beginFileTransfer(self.open_file, consumer)
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +13,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import errno
|
||||
import fnmatch
|
||||
|
@ -23,12 +23,13 @@ import re
|
|||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
|
@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
|
|||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
|
|||
|
||||
from ._base import FileInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lxml import etree
|
||||
|
||||
from synapse.app.homeserver import HomeServer
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
|
||||
|
@ -119,7 +127,12 @@ class OEmbedError(Exception):
|
|||
class PreviewUrlResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
media_repo: "MediaRepository",
|
||||
media_storage: MediaStorage,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
self._start_expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_GET(self, request):
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
|
||||
raise OEmbedError() from e
|
||||
|
||||
async def _download_url(self, url: str, user):
|
||||
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
|
@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
"expire_url_cache_data", self._expire_url_cache_data
|
||||
)
|
||||
|
||||
async def _expire_url_cache_data(self):
|
||||
async def _expire_url_cache_data(self) -> None:
|
||||
"""Clean up expired url cache content, media and thumbnails.
|
||||
"""
|
||||
# TODO: Delete from backup media store
|
||||
|
@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
logger.debug("No media removed from url cache")
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
|
||||
def decode_and_calc_og(
|
||||
body: bytes, media_uri: str, request_encoding: Optional[str] = None
|
||||
) -> Dict[str, Optional[str]]:
|
||||
# If there's no body, nothing useful is going to be found.
|
||||
if not body:
|
||||
return {}
|
||||
|
@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
|
|||
return og
|
||||
|
||||
|
||||
def _calc_og(tree, media_uri):
|
||||
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
|
@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
|
|||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
og["og:description"] = summarize_paragraphs(text_nodes)
|
||||
else:
|
||||
elif og["og:description"]:
|
||||
# This must be a non-empty string at this point.
|
||||
assert isinstance(og["og:description"], str)
|
||||
og["og:description"] = summarize_paragraphs([og["og:description"]])
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
|
@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
|
|||
return og
|
||||
|
||||
|
||||
def _iterate_over_text(tree, *tags_to_ignore):
|
||||
def _iterate_over_text(
|
||||
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
|
||||
) -> Generator[str, None, None]:
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
skipping text nodes inside certain tags.
|
||||
"""
|
||||
|
@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
|
|||
)
|
||||
|
||||
|
||||
def _rebase_url(url, base):
|
||||
base = list(urlparse.urlparse(base))
|
||||
url = list(urlparse.urlparse(url))
|
||||
if not url[0]: # fix up schema
|
||||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith("/"):
|
||||
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
def _rebase_url(url: str, base: str) -> str:
|
||||
base_parts = list(urlparse.urlparse(base))
|
||||
url_parts = list(urlparse.urlparse(url))
|
||||
if not url_parts[0]: # fix up schema
|
||||
url_parts[0] = base_parts[0] or "http"
|
||||
if not url_parts[1]: # fix up hostname
|
||||
url_parts[1] = base_parts[1]
|
||||
if not url_parts[2].startswith("/"):
|
||||
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
|
||||
return urlparse.urlunparse(url_parts)
|
||||
|
||||
|
||||
def _is_media(content_type):
|
||||
if content_type.lower().startswith("image/"):
|
||||
return True
|
||||
def _is_media(content_type: str) -> bool:
|
||||
return content_type.lower().startswith("image/")
|
||||
|
||||
|
||||
def _is_html(content_type):
|
||||
def _is_html(content_type: str) -> bool:
|
||||
content_type = content_type.lower()
|
||||
if content_type.startswith("text/html") or content_type.startswith(
|
||||
return content_type.startswith("text/html") or content_type.startswith(
|
||||
"application/xhtml"
|
||||
):
|
||||
return True
|
||||
)
|
||||
|
||||
|
||||
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||
def summarize_paragraphs(
|
||||
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
|
||||
) -> Optional[str]:
|
||||
# Try to get a summary of between 200 and 500 words, respecting
|
||||
# first paragraph and then word boundaries.
|
||||
# TODO: Respect sentences?
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,10 +13,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.logging.context import defer_to_thread, run_in_background
|
||||
|
@ -27,13 +28,17 @@ from .media_storage import FileResponder
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
class StorageProvider:
|
||||
|
||||
class StorageProvider(metaclass=abc.ABCMeta):
|
||||
"""A storage provider is a service that can store uploaded media and
|
||||
retrieve them.
|
||||
"""
|
||||
|
||||
async def store_file(self, path: str, file_info: FileInfo):
|
||||
@abc.abstractmethod
|
||||
async def store_file(self, path: str, file_info: FileInfo) -> None:
|
||||
"""Store the file described by file_info. The actual contents can be
|
||||
retrieved by reading the file in file_info.upload_path.
|
||||
|
||||
|
@ -42,6 +47,7 @@ class StorageProvider:
|
|||
file_info: The metadata of the file.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
|
||||
"""Attempt to fetch the file described by file_info and stream it
|
||||
into writer.
|
||||
|
@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
|
|||
self.store_synchronous = store_synchronous
|
||||
self.store_remote = store_remote
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "StorageProviderWrapper[%s]" % (self.backend,)
|
||||
|
||||
async def store_file(self, path, file_info):
|
||||
async def store_file(self, path: str, file_info: FileInfo) -> None:
|
||||
if not file_info.server_name and not self.store_local:
|
||||
return None
|
||||
|
||||
|
@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
|
|||
if self.store_synchronous:
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
return await maybe_awaitable(self.backend.store_file(path, file_info))
|
||||
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
async def store():
|
||||
|
@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
|
|||
logger.exception("Error storing file")
|
||||
|
||||
run_in_background(store)
|
||||
return None
|
||||
|
||||
async def fetch(self, path, file_info):
|
||||
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
return await maybe_awaitable(self.backend.fetch(path, file_info))
|
||||
|
@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
"""A storage provider that stores files in a directory on a filesystem.
|
||||
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
hs
|
||||
config: The config returned by `parse_config`.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, config):
|
||||
def __init__(self, hs: "HomeServer", config: str):
|
||||
self.hs = hs
|
||||
self.cache_directory = hs.config.media_store_path
|
||||
self.base_directory = config
|
||||
|
@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
def __str__(self):
|
||||
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
|
||||
|
||||
async def store_file(self, path, file_info):
|
||||
async def store_file(self, path: str, file_info: FileInfo) -> None:
|
||||
"""See StorageProvider.store_file"""
|
||||
|
||||
primary_fname = os.path.join(self.cache_directory, path)
|
||||
|
@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
return await defer_to_thread(
|
||||
await defer_to_thread(
|
||||
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
|
||||
)
|
||||
|
||||
async def fetch(self, path, file_info):
|
||||
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
|
||||
"""See StorageProvider.fetch"""
|
||||
|
||||
backup_fname = os.path.join(self.base_directory, path)
|
||||
if os.path.isfile(backup_fname):
|
||||
return FileResponder(open(backup_fname, "rb"))
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
def parse_config(config: dict) -> str:
|
||||
"""Called on startup to parse config supplied. This should parse
|
||||
the config and raise if there is a problem.
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,10 +16,14 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
|
||||
from ._base import (
|
||||
FileInfo,
|
||||
|
@ -28,13 +33,22 @@ from ._base import (
|
|||
respond_with_responder,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThumbnailResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
media_repo: "MediaRepository",
|
||||
media_storage: MediaStorage,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def _async_render_GET(self, request):
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
set_cors_headers(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
width = parse_integer(request, "width", required=True)
|
||||
|
@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
async def _respond_local_thumbnail(
|
||||
self, request, media_id, width, height, method, m_type
|
||||
):
|
||||
self,
|
||||
request: Request,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
) -> None:
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
|
@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
request: Request,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
) -> None:
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
|
@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request,
|
||||
server_name,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
request: Request,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
) -> None:
|
||||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
|
@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
raise SynapseError(400, "Failed to generate thumbnail.")
|
||||
|
||||
async def _respond_remote_thumbnail(
|
||||
self, request, server_name, media_id, width, height, method, m_type
|
||||
):
|
||||
self,
|
||||
request: Request,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
method: str,
|
||||
m_type: str,
|
||||
) -> None:
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
|
@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
def _select_thumbnail(
|
||||
self,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos,
|
||||
):
|
||||
) -> dict:
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -39,7 +41,7 @@ class Thumbnailer:
|
|||
|
||||
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
|
||||
|
||||
def __init__(self, input_path):
|
||||
def __init__(self, input_path: str):
|
||||
try:
|
||||
self.image = Image.open(input_path)
|
||||
except OSError as e:
|
||||
|
@ -59,11 +61,11 @@ class Thumbnailer:
|
|||
# A lot of parsing errors can happen when parsing EXIF
|
||||
logger.info("Error parsing image EXIF information: %s", e)
|
||||
|
||||
def transpose(self):
|
||||
def transpose(self) -> Tuple[int, int]:
|
||||
"""Transpose the image using its EXIF Orientation tag
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (width, height) containing the new image size in pixels.
|
||||
A tuple containing the new image size in pixels as (width, height).
|
||||
"""
|
||||
if self.transpose_method is not None:
|
||||
self.image = self.image.transpose(self.transpose_method)
|
||||
|
@ -73,7 +75,7 @@ class Thumbnailer:
|
|||
self.image.info["exif"] = None
|
||||
return self.image.size
|
||||
|
||||
def aspect(self, max_width, max_height):
|
||||
def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
|
||||
"""Calculate the largest size that preserves aspect ratio which
|
||||
fits within the given rectangle::
|
||||
|
||||
|
@ -91,7 +93,7 @@ class Thumbnailer:
|
|||
else:
|
||||
return (max_height * self.width) // self.height, max_height
|
||||
|
||||
def _resize(self, width, height):
|
||||
def _resize(self, width: int, height: int) -> Image:
|
||||
# 1-bit or 8-bit color palette images need converting to RGB
|
||||
# otherwise they will be scaled using nearest neighbour which
|
||||
# looks awful
|
||||
|
@ -99,7 +101,7 @@ class Thumbnailer:
|
|||
self.image = self.image.convert("RGB")
|
||||
return self.image.resize((width, height), Image.ANTIALIAS)
|
||||
|
||||
def scale(self, width, height, output_type):
|
||||
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
|
||||
"""Rescales the image to the given dimensions.
|
||||
|
||||
Returns:
|
||||
|
@ -108,7 +110,7 @@ class Thumbnailer:
|
|||
scaled = self._resize(width, height)
|
||||
return self._encode_image(scaled, output_type)
|
||||
|
||||
def crop(self, width, height, output_type):
|
||||
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
|
||||
"""Rescales and crops the image to the given dimensions preserving
|
||||
aspect::
|
||||
(w_in / h_in) = (w_scaled / h_scaled)
|
||||
|
@ -136,7 +138,7 @@ class Thumbnailer:
|
|||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
return self._encode_image(cropped, output_type)
|
||||
|
||||
def _encode_image(self, output_image, output_type):
|
||||
def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
|
||||
output_bytes_io = BytesIO()
|
||||
fmt = self.FORMATS[output_type]
|
||||
if fmt == "JPEG":
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,18 +15,25 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
|
||||
self.media_repo = media_repo
|
||||
|
@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
|
|||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_POST(self, request):
|
||||
async def _async_render_POST(self, request: Request) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
|
@ -28,6 +30,25 @@ from synapse.types import JsonDict
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class _CalculateChainCover:
|
||||
"""Return value for _calculate_chain_cover_txn.
|
||||
"""
|
||||
|
||||
# The last room_id/depth/stream processed.
|
||||
room_id = attr.ib(type=str)
|
||||
depth = attr.ib(type=int)
|
||||
stream = attr.ib(type=int)
|
||||
|
||||
# Number of rows processed
|
||||
processed_count = attr.ib(type=int)
|
||||
|
||||
# Map from room_id to last depth/stream processed for each room that we have
|
||||
# processed all events for (i.e. the rooms we can flip the
|
||||
# `has_auth_chain_index` for)
|
||||
finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
|
||||
|
||||
|
||||
class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||
|
@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
current_room_id = progress.get("current_room_id", "")
|
||||
|
||||
# Have we finished processing the current room.
|
||||
finished = progress.get("finished", True)
|
||||
|
||||
# Where we've processed up to in the room, defaults to the start of the
|
||||
# room.
|
||||
last_depth = progress.get("last_depth", -1)
|
||||
last_stream = progress.get("last_stream", -1)
|
||||
|
||||
# Have we set the `has_auth_chain_index` for the room yet.
|
||||
has_set_room_has_chain_index = progress.get(
|
||||
"has_set_room_has_chain_index", False
|
||||
result = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index",
|
||||
self._calculate_chain_cover_txn,
|
||||
current_room_id,
|
||||
last_depth,
|
||||
last_stream,
|
||||
batch_size,
|
||||
single_room=False,
|
||||
)
|
||||
|
||||
if finished:
|
||||
# If we've finished with the previous room (or its our first
|
||||
# iteration) we move on to the next room.
|
||||
finished = result.processed_count == 0
|
||||
|
||||
def _get_next_room(txn: Cursor) -> Optional[str]:
|
||||
sql = """
|
||||
SELECT room_id FROM rooms
|
||||
WHERE room_id > ?
|
||||
AND (
|
||||
NOT has_auth_chain_index
|
||||
OR has_auth_chain_index IS NULL
|
||||
)
|
||||
ORDER BY room_id
|
||||
LIMIT 1
|
||||
"""
|
||||
txn.execute(sql, (current_room_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
total_rows_processed = result.processed_count
|
||||
current_room_id = result.room_id
|
||||
last_depth = result.depth
|
||||
last_stream = result.stream
|
||||
|
||||
return None
|
||||
|
||||
current_room_id = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _get_next_room
|
||||
)
|
||||
if not current_room_id:
|
||||
await self.db_pool.updates._end_background_update("chain_cover")
|
||||
return 0
|
||||
|
||||
logger.debug("Adding chain cover to %s", current_room_id)
|
||||
|
||||
def _calculate_auth_chain(
|
||||
txn: Cursor, last_depth: int, last_stream: int
|
||||
) -> Tuple[int, int, int]:
|
||||
# Get the next set of events in the room (that we haven't already
|
||||
# computed chain cover for). We do this in topological order.
|
||||
|
||||
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
|
||||
# comparison, but that is not supported on older SQLite versions
|
||||
tuple_clause, tuple_args = make_tuple_comparison_clause(
|
||||
self.database_engine,
|
||||
[
|
||||
("topological_ordering", last_depth),
|
||||
("stream_ordering", last_stream),
|
||||
],
|
||||
)
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
event_id, state_events.type, state_events.state_key,
|
||||
topological_ordering, stream_ordering
|
||||
FROM events
|
||||
INNER JOIN state_events USING (event_id)
|
||||
LEFT JOIN event_auth_chains USING (event_id)
|
||||
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
|
||||
WHERE events.room_id = ?
|
||||
AND event_auth_chains.event_id IS NULL
|
||||
AND event_auth_chain_to_calculate.event_id IS NULL
|
||||
AND %(tuple_cmp)s
|
||||
ORDER BY topological_ordering, stream_ordering
|
||||
LIMIT ?
|
||||
""" % {
|
||||
"tuple_cmp": tuple_clause,
|
||||
}
|
||||
|
||||
args = [current_room_id]
|
||||
args.extend(tuple_args)
|
||||
args.append(batch_size)
|
||||
|
||||
txn.execute(sql, args)
|
||||
rows = txn.fetchall()
|
||||
|
||||
# Put the results in the necessary format for
|
||||
# `_add_chain_cover_index`
|
||||
event_to_room_id = {row[0]: current_room_id for row in rows}
|
||||
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
|
||||
|
||||
new_last_depth = rows[-1][3] if rows else last_depth # type: int
|
||||
new_last_stream = rows[-1][4] if rows else last_stream # type: int
|
||||
|
||||
count = len(rows)
|
||||
|
||||
# We also need to fetch the auth events for them.
|
||||
auth_events = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
column="event_id",
|
||||
iterable=event_to_room_id,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "auth_id"),
|
||||
)
|
||||
|
||||
event_to_auth_chain = {} # type: Dict[str, List[str]]
|
||||
for row in auth_events:
|
||||
event_to_auth_chain.setdefault(row["event_id"], []).append(
|
||||
row["auth_id"]
|
||||
)
|
||||
|
||||
# Calculate and persist the chain cover index for this set of events.
|
||||
#
|
||||
# Annoyingly we need to gut wrench into the persit event store so that
|
||||
# we can reuse the function to calculate the chain cover for rooms.
|
||||
PersistEventsStore._add_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
|
||||
return new_last_depth, new_last_stream, count
|
||||
|
||||
last_depth, last_stream, count = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
|
||||
)
|
||||
|
||||
total_rows_processed = count
|
||||
|
||||
if count < batch_size and not has_set_room_has_chain_index:
|
||||
for room_id, (depth, stream) in result.finished_room_map.items():
|
||||
# If we've done all the events in the room we flip the
|
||||
# `has_auth_chain_index` in the DB. Note that its possible for
|
||||
# further events to be persisted between the above and setting the
|
||||
|
@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
await self.db_pool.simple_update(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": current_room_id},
|
||||
keyvalues={"room_id": room_id},
|
||||
updatevalues={"has_auth_chain_index": True},
|
||||
desc="_chain_cover_index",
|
||||
)
|
||||
has_set_room_has_chain_index = True
|
||||
|
||||
# Handle any events that might have raced with us flipping the
|
||||
# bit above.
|
||||
last_depth, last_stream, count = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
|
||||
result = await self.db_pool.runInteraction(
|
||||
"_chain_cover_index",
|
||||
self._calculate_chain_cover_txn,
|
||||
room_id,
|
||||
depth,
|
||||
stream,
|
||||
batch_size=None,
|
||||
single_room=True,
|
||||
)
|
||||
|
||||
total_rows_processed += count
|
||||
total_rows_processed += result.processed_count
|
||||
|
||||
# Note that at this point its technically possible that more events
|
||||
# than our `batch_size` have been persisted without their chain
|
||||
# cover, so we need to continue processing this room if the last
|
||||
# count returned was equal to the `batch_size`.
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update("chain_cover")
|
||||
return total_rows_processed
|
||||
|
||||
if count < batch_size:
|
||||
# We've finished calculating the index for this room, move on to the
|
||||
# next room.
|
||||
await self.db_pool.updates._background_update_progress(
|
||||
"chain_cover", {"current_room_id": current_room_id, "finished": True},
|
||||
)
|
||||
else:
|
||||
# We still have outstanding events to calculate the index for.
|
||||
await self.db_pool.updates._background_update_progress(
|
||||
"chain_cover",
|
||||
{
|
||||
"current_room_id": current_room_id,
|
||||
"last_depth": last_depth,
|
||||
"last_stream": last_stream,
|
||||
"has_auth_chain_index": has_set_room_has_chain_index,
|
||||
"finished": False,
|
||||
},
|
||||
)
|
||||
await self.db_pool.updates._background_update_progress(
|
||||
"chain_cover",
|
||||
{
|
||||
"current_room_id": current_room_id,
|
||||
"last_depth": last_depth,
|
||||
"last_stream": last_stream,
|
||||
},
|
||||
)
|
||||
|
||||
return total_rows_processed
|
||||
|
||||
def _calculate_chain_cover_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
last_room_id: str,
|
||||
last_depth: int,
|
||||
last_stream: int,
|
||||
batch_size: Optional[int],
|
||||
single_room: bool,
|
||||
) -> _CalculateChainCover:
|
||||
"""Calculate the chain cover for `batch_size` events, ordered by
|
||||
`(room_id, depth, stream)`.
|
||||
|
||||
Args:
|
||||
txn,
|
||||
last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
|
||||
tuple to fetch results after.
|
||||
batch_size: The maximum number of events to process. If None then
|
||||
no limit.
|
||||
single_room: Whether to calculate the index for just the given
|
||||
room.
|
||||
"""
|
||||
|
||||
# Get the next set of events in the room (that we haven't already
|
||||
# computed chain cover for). We do this in topological order.
|
||||
|
||||
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
|
||||
# comparison, but that is not supported on older SQLite versions
|
||||
tuple_clause, tuple_args = make_tuple_comparison_clause(
|
||||
self.database_engine,
|
||||
[
|
||||
("events.room_id", last_room_id),
|
||||
("topological_ordering", last_depth),
|
||||
("stream_ordering", last_stream),
|
||||
],
|
||||
)
|
||||
|
||||
extra_clause = ""
|
||||
if single_room:
|
||||
extra_clause = "AND events.room_id = ?"
|
||||
tuple_args.append(last_room_id)
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
event_id, state_events.type, state_events.state_key,
|
||||
topological_ordering, stream_ordering,
|
||||
events.room_id
|
||||
FROM events
|
||||
INNER JOIN state_events USING (event_id)
|
||||
LEFT JOIN event_auth_chains USING (event_id)
|
||||
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
|
||||
WHERE event_auth_chains.event_id IS NULL
|
||||
AND event_auth_chain_to_calculate.event_id IS NULL
|
||||
AND %(tuple_cmp)s
|
||||
%(extra)s
|
||||
ORDER BY events.room_id, topological_ordering, stream_ordering
|
||||
%(limit)s
|
||||
""" % {
|
||||
"tuple_cmp": tuple_clause,
|
||||
"limit": "LIMIT ?" if batch_size is not None else "",
|
||||
"extra": extra_clause,
|
||||
}
|
||||
|
||||
if batch_size is not None:
|
||||
tuple_args.append(batch_size)
|
||||
|
||||
txn.execute(sql, tuple_args)
|
||||
rows = txn.fetchall()
|
||||
|
||||
# Put the results in the necessary format for
|
||||
# `_add_chain_cover_index`
|
||||
event_to_room_id = {row[0]: row[5] for row in rows}
|
||||
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
|
||||
|
||||
# Calculate the new last position we've processed up to.
|
||||
new_last_depth = rows[-1][3] if rows else last_depth # type: int
|
||||
new_last_stream = rows[-1][4] if rows else last_stream # type: int
|
||||
new_last_room_id = rows[-1][5] if rows else "" # type: str
|
||||
|
||||
# Map from room_id to last depth/stream_ordering processed for the room,
|
||||
# excluding the last room (which we're likely still processing). We also
|
||||
# need to include the room passed in if it's not included in the result
|
||||
# set (as we then know we've processed all events in said room).
|
||||
#
|
||||
# This is the set of rooms that we can now safely flip the
|
||||
# `has_auth_chain_index` bit for.
|
||||
finished_rooms = {
|
||||
row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
|
||||
}
|
||||
if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
|
||||
finished_rooms[last_room_id] = (last_depth, last_stream)
|
||||
|
||||
count = len(rows)
|
||||
|
||||
# We also need to fetch the auth events for them.
|
||||
auth_events = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
column="event_id",
|
||||
iterable=event_to_room_id,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "auth_id"),
|
||||
)
|
||||
|
||||
event_to_auth_chain = {} # type: Dict[str, List[str]]
|
||||
for row in auth_events:
|
||||
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
|
||||
|
||||
# Calculate and persist the chain cover index for this set of events.
|
||||
#
|
||||
# Annoyingly we need to gut wrench into the persit event store so that
|
||||
# we can reuse the function to calculate the chain cover for rooms.
|
||||
PersistEventsStore._add_chain_cover_index(
|
||||
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||
)
|
||||
|
||||
return _CalculateChainCover(
|
||||
room_id=new_last_room_id,
|
||||
depth=new_last_depth,
|
||||
stream=new_last_stream,
|
||||
processed_count=count,
|
||||
finished_room_map=finished_rooms,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def get_local_media_before(
|
||||
self, before_ts: int, size_gt: int, keep_profiles: bool,
|
||||
) -> Optional[List[str]]:
|
||||
) -> List[str]:
|
||||
|
||||
# to find files that have never been accessed (last_access_ts IS NULL)
|
||||
# compare with `created_ts`
|
||||
|
|
|
@ -17,14 +17,13 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.push import PusherConfig, ThrottleParams
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
|
|||
"device_display_name": device_display_name,
|
||||
"ts": pushkey_ts,
|
||||
"lang": lang,
|
||||
"data": bytearray(encode_canonical_json(data)),
|
||||
"data": json_encoder.encode(data),
|
||||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
|
|
|
@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
|||
if len(items) <= maxitems:
|
||||
return str(items)
|
||||
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
|
||||
|
||||
|
||||
def strtobool(val: str) -> bool:
|
||||
"""Convert a string representation of truth to True or False
|
||||
|
||||
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
|
||||
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
|
||||
'val' is anything else.
|
||||
|
||||
This is lifted from distutils.util.strtobool, with the exception that it actually
|
||||
returns a bool, rather than an int.
|
||||
"""
|
||||
val = val.lower()
|
||||
if val in ("y", "yes", "t", "true", "on", "1"):
|
||||
return True
|
||||
elif val in ("n", "no", "f", "false", "off", "0"):
|
||||
return False
|
||||
else:
|
||||
raise ValueError("invalid truth value %r" % (val,))
|
||||
|
|
|
@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||
|
||||
self.handler = hs.get_oidc_handler()
|
||||
self.provider = self.handler._provider
|
||||
self.provider = self.handler._providers["oidc"]
|
||||
sso_handler = hs.get_sso_handler()
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
|
@ -866,7 +866,7 @@ async def _make_callback_with_userinfo(
|
|||
from synapse.handlers.oidc_handler import OidcSessionData
|
||||
|
||||
handler = hs.get_oidc_handler()
|
||||
provider = handler._provider
|
||||
provider = handler._providers["oidc"]
|
||||
provider._exchange_code = simple_async_mock(return_value={})
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
|
|
|
@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||
# Expire both caches and repeat the request
|
||||
self.reactor.pump((10000.0,))
|
||||
|
||||
# Repated the request, this time it should fail if the lookup fails.
|
||||
# Repeat the request, this time it should fail if the lookup fails.
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||
content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
|
||||
)
|
||||
|
||||
# The result is sucessful, but disabled delegation.
|
||||
# The result is successful, but disabled delegation.
|
||||
r = self.successResultOf(fetch_d)
|
||||
self.assertIsNone(r.delegated_server)
|
||||
|
||||
|
|
101
tests/http/test_client.py
Normal file
101
tests/http/test_client.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import ResponseDone
|
||||
|
||||
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class ReadBodyWithMaxSizeTests(TestCase):
|
||||
def setUp(self):
|
||||
"""Start reading the body, returns the response, result and proto"""
|
||||
self.response = Mock()
|
||||
self.result = BytesIO()
|
||||
self.deferred = read_body_with_max_size(self.response, self.result, 6)
|
||||
|
||||
# Fish the protocol out of the response.
|
||||
self.protocol = self.response.deliverBody.call_args[0][0]
|
||||
self.protocol.transport = Mock()
|
||||
|
||||
def _cleanup_error(self):
|
||||
"""Ensure that the error in the Deferred is handled gracefully."""
|
||||
called = [False]
|
||||
|
||||
def errback(f):
|
||||
called[0] = True
|
||||
|
||||
self.deferred.addErrback(errback)
|
||||
self.assertTrue(called[0])
|
||||
|
||||
def test_no_error(self):
|
||||
"""A response that is NOT too large."""
|
||||
|
||||
# Start sending data.
|
||||
self.protocol.dataReceived(b"12345")
|
||||
# Close the connection.
|
||||
self.protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
self.assertEqual(self.result.getvalue(), b"12345")
|
||||
self.assertEqual(self.deferred.result, 5)
|
||||
|
||||
def test_too_large(self):
|
||||
"""A response which is too large raises an exception."""
|
||||
|
||||
# Start sending data.
|
||||
self.protocol.dataReceived(b"1234567890")
|
||||
# Close the connection.
|
||||
self.protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
self.assertEqual(self.result.getvalue(), b"1234567890")
|
||||
self.assertIsInstance(self.deferred.result, Failure)
|
||||
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
|
||||
self._cleanup_error()
|
||||
|
||||
def test_multiple_packets(self):
|
||||
"""Data should be accummulated through mutliple packets."""
|
||||
|
||||
# Start sending data.
|
||||
self.protocol.dataReceived(b"12")
|
||||
self.protocol.dataReceived(b"34")
|
||||
# Close the connection.
|
||||
self.protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
self.assertEqual(self.result.getvalue(), b"1234")
|
||||
self.assertEqual(self.deferred.result, 4)
|
||||
|
||||
def test_additional_data(self):
|
||||
"""A connection can receive data after being closed."""
|
||||
|
||||
# Start sending data.
|
||||
self.protocol.dataReceived(b"1234567890")
|
||||
self.assertIsInstance(self.deferred.result, Failure)
|
||||
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
|
||||
self.protocol.transport.loseConnection.assert_called_once()
|
||||
|
||||
# More data might have come in.
|
||||
self.protocol.dataReceived(b"1234567890")
|
||||
# Close the connection.
|
||||
self.protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
self.assertEqual(self.result.getvalue(), b"1234567890")
|
||||
self.assertIsInstance(self.deferred.result, Failure)
|
||||
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
|
||||
self._cleanup_error()
|
|
@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
# Allow for uploading and downloading to/from the media repo
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
self.download_resource = self.media_repo.children[b"download"]
|
||||
|
@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Mark the second item as safe from quarantine.
|
||||
_, media_id_2 = server_and_media_id_2.split("/")
|
||||
self.get_success(self.store.mark_local_media_as_safe(media_id_2))
|
||||
# Quarantine the media
|
||||
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
|
||||
channel = self.make_request("POST", url, access_token=admin_user_tok)
|
||||
self.pump(1.0)
|
||||
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
|
||||
|
||||
# Quarantine all media by this user
|
||||
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
|
||||
|
|
|
@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
session_id = channel.json_body["session"]
|
||||
|
||||
# do the OIDC auth, but auth as the wrong user
|
||||
channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id)
|
||||
channel = self.helper.auth_via_oidc(
|
||||
{"sub": "wrong_user"}, ui_auth_session_id=session_id
|
||||
)
|
||||
|
||||
# that should return a failure message
|
||||
self.assertSubstring("We were unable to validate", channel.text_body)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def test_background_update(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.user_id = self.register_user("foo", "pass")
|
||||
self.token = self.login("foo", "pass")
|
||||
self.requester = create_requester(self.user_id)
|
||||
|
||||
def _generate_room(self) -> Tuple[str, List[Set[str]]]:
|
||||
"""Insert a room without a chain cover index.
|
||||
"""
|
||||
|
||||
# Create a room
|
||||
user_id = self.register_user("foo", "pass")
|
||||
token = self.login("foo", "pass")
|
||||
room_id = self.helper.create_room_as(user_id, tok=token)
|
||||
requester = create_requester(user_id)
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
# Mark the room as not having a chain cover index
|
||||
self.get_success(
|
||||
store.db_pool.simple_update(
|
||||
self.store.db_pool.simple_update(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
updatevalues={"has_auth_chain_index": False},
|
||||
|
@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
|
||||
# Create a fork in the DAG with different events.
|
||||
event_handler = self.hs.get_event_creation_handler()
|
||||
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
||||
latest_event_ids = self.get_success(
|
||||
self.store.get_prev_events_for_room(room_id)
|
||||
)
|
||||
event, context = self.get_success(
|
||||
event_handler.create_event(
|
||||
requester,
|
||||
self.requester,
|
||||
{
|
||||
"type": "some_state_type",
|
||||
"state_key": "",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"sender": self.user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(requester, event, context)
|
||||
event_handler.handle_new_client_event(self.requester, event, context)
|
||||
)
|
||||
state1 = list(self.get_success(context.get_current_state_ids()).values())
|
||||
state1 = set(self.get_success(context.get_current_state_ids()).values())
|
||||
|
||||
event, context = self.get_success(
|
||||
event_handler.create_event(
|
||||
requester,
|
||||
self.requester,
|
||||
{
|
||||
"type": "some_state_type",
|
||||
"state_key": "",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"sender": self.user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(requester, event, context)
|
||||
event_handler.handle_new_client_event(self.requester, event, context)
|
||||
)
|
||||
state2 = list(self.get_success(context.get_current_state_ids()).values())
|
||||
state2 = set(self.get_success(context.get_current_state_ids()).values())
|
||||
|
||||
# Delete the chain cover info.
|
||||
|
||||
|
@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
txn.execute("DELETE FROM event_auth_chains")
|
||||
txn.execute("DELETE FROM event_auth_chain_links")
|
||||
|
||||
self.get_success(store.db_pool.runInteraction("test", _delete_tables))
|
||||
self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
|
||||
|
||||
return room_id, [state1, state2]
|
||||
|
||||
def test_background_update_single_room(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
"""
|
||||
|
||||
# Create a room
|
||||
room_id, states = self._generate_room()
|
||||
|
||||
# Insert and run the background update.
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
self.store.db_pool.simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": "chain_cover", "progress_json": "{}"},
|
||||
)
|
||||
)
|
||||
|
||||
# Ugh, have to reset this flag
|
||||
store.db_pool.updates._all_done = False
|
||||
self.store.db_pool.updates._all_done = False
|
||||
|
||||
while not self.get_success(
|
||||
store.db_pool.updates.has_completed_background_updates()
|
||||
self.store.db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
self.get_success(
|
||||
store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
self.store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
)
|
||||
|
||||
# Test that the `has_auth_chain_index` has been set
|
||||
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
|
||||
|
||||
# Test that calculating the auth chain difference using the newly
|
||||
# calculated chain cover works.
|
||||
self.get_success(
|
||||
store.db_pool.runInteraction(
|
||||
self.store.db_pool.runInteraction(
|
||||
"test",
|
||||
store._get_auth_chain_difference_using_cover_index_txn,
|
||||
self.store._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
[state1, state2],
|
||||
states,
|
||||
)
|
||||
)
|
||||
|
||||
def test_background_update_multiple_rooms(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
"""
|
||||
# Create a room
|
||||
room_id1, states1 = self._generate_room()
|
||||
room_id2, states2 = self._generate_room()
|
||||
room_id3, states2 = self._generate_room()
|
||||
|
||||
# Insert and run the background update.
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": "chain_cover", "progress_json": "{}"},
|
||||
)
|
||||
)
|
||||
|
||||
# Ugh, have to reset this flag
|
||||
self.store.db_pool.updates._all_done = False
|
||||
|
||||
while not self.get_success(
|
||||
self.store.db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
self.get_success(
|
||||
self.store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
)
|
||||
|
||||
# Test that the `has_auth_chain_index` has been set
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
|
||||
|
||||
# Test that calculating the auth chain difference using the newly
|
||||
# calculated chain cover works.
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"test",
|
||||
self.store._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id1,
|
||||
states1,
|
||||
)
|
||||
)
|
||||
|
||||
def test_background_update_single_large_room(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
"""
|
||||
|
||||
# Create a room
|
||||
room_id, states = self._generate_room()
|
||||
|
||||
# Add a bunch of state so that it takes multiple iterations of the
|
||||
# background update to process the room.
|
||||
for i in range(0, 150):
|
||||
self.helper.send_state(
|
||||
room_id, event_type="m.test", body={"index": i}, tok=self.token
|
||||
)
|
||||
|
||||
# Insert and run the background update.
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": "chain_cover", "progress_json": "{}"},
|
||||
)
|
||||
)
|
||||
|
||||
# Ugh, have to reset this flag
|
||||
self.store.db_pool.updates._all_done = False
|
||||
|
||||
iterations = 0
|
||||
while not self.get_success(
|
||||
self.store.db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
iterations += 1
|
||||
self.get_success(
|
||||
self.store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
)
|
||||
|
||||
# Ensure that we did actually take multiple iterations to process the
|
||||
# room.
|
||||
self.assertGreater(iterations, 1)
|
||||
|
||||
# Test that the `has_auth_chain_index` has been set
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
|
||||
|
||||
# Test that calculating the auth chain difference using the newly
|
||||
# calculated chain cover works.
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"test",
|
||||
self.store._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
states,
|
||||
)
|
||||
)
|
||||
|
||||
def test_background_update_multiple_large_room(self):
|
||||
"""Test that the background update to calculate auth chains for historic
|
||||
rooms works correctly.
|
||||
"""
|
||||
|
||||
# Create the rooms
|
||||
room_id1, _ = self._generate_room()
|
||||
room_id2, _ = self._generate_room()
|
||||
|
||||
# Add a bunch of state so that it takes multiple iterations of the
|
||||
# background update to process the room.
|
||||
for i in range(0, 150):
|
||||
self.helper.send_state(
|
||||
room_id1, event_type="m.test", body={"index": i}, tok=self.token
|
||||
)
|
||||
|
||||
for i in range(0, 150):
|
||||
self.helper.send_state(
|
||||
room_id2, event_type="m.test", body={"index": i}, tok=self.token
|
||||
)
|
||||
|
||||
# Insert and run the background update.
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": "chain_cover", "progress_json": "{}"},
|
||||
)
|
||||
)
|
||||
|
||||
# Ugh, have to reset this flag
|
||||
self.store.db_pool.updates._all_done = False
|
||||
|
||||
iterations = 0
|
||||
while not self.get_success(
|
||||
self.store.db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
iterations += 1
|
||||
self.get_success(
|
||||
self.store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
)
|
||||
|
||||
# Ensure that we did actually take multiple iterations to process the
|
||||
# room.
|
||||
self.assertGreater(iterations, 1)
|
||||
|
||||
# Test that the `has_auth_chain_index` has been set
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
|
||||
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
|
||||
|
|
Loading…
Reference in a new issue