Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

Erik Johnston 2021-01-18 11:14:37 +00:00
commit f5ab7d8306
43 changed files with 1337 additions and 753 deletions

changelog.d/9086.feature Normal file

@ -0,0 +1 @@
Add an admin API for protecting local media from quarantine.

changelog.d/9093.misc Normal file

@ -0,0 +1 @@
Add type hints to media repository.

changelog.d/9108.bugfix Normal file

@ -0,0 +1 @@
Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large.

changelog.d/9110.feature Normal file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

changelog.d/9117.bugfix Normal file

@ -0,0 +1 @@
Fix corruption of `pushers` data when a postgres bouncer is used.

changelog.d/9124.misc Normal file

@ -0,0 +1 @@
Improve efficiency of large state resolutions.

changelog.d/9125.misc Normal file

@ -0,0 +1 @@
Remove dependency on `distutils`.

changelog.d/9130.feature Normal file

@ -0,0 +1 @@
Add experimental support for handling and persisting to-device messages on worker processes.

debian/changelog vendored

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium

  * Remove dependency on `python3-distutils`.

 -- Richard van der Hoff <richard@matrix.org>  Fri, 15 Jan 2021 12:44:19 +0000

matrix-synapse-py3 (1.25.0) stable; urgency=medium

  [ Dan Callahan ]

debian/control vendored

@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
Depends:
adduser,
debconf,
python3-distutils|libpython3-stdlib (<< 3.6),
${misc:Depends},
${shlibs:Depends},
${synapse:pydepends},

docs/admin_api/media_admin_api.md

@ -4,6 +4,7 @@
* [Quarantining media by ID](#quarantining-media-by-id)
* [Quarantining media in a room](#quarantining-media-in-a-room)
* [Quarantining all media of a user](#quarantining-all-media-of-a-user)
* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
- [Delete local media](#delete-local-media)
* [Delete a specific local media](#delete-a-specific-local-media)
* [Delete local media by date or size](#delete-local-media-by-date-or-size)
@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:
* `num_quarantined`: integer - The number of media items successfully quarantined
## Protecting media from being quarantined
This API protects a single piece of local media from being quarantined using the
above APIs. This is useful for sticker packs and other shared media which you do
not want to get quarantined, especially when
[quarantining media in a room](#quarantining-media-in-a-room).
Request:
```
POST /_synapse/admin/v1/media/protect/<media_id>
{}
```
Where `media_id` is in the form of `abcdefg12345...`.
Response:
```json
{}
```
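
For illustration, one way to drive this endpoint from a script (a sketch, not part of the change: the homeserver URL and admin access token are placeholders, and it assumes the usual `Authorization: Bearer` admin-API authentication):

```python
import requests  # assumption: the third-party `requests` library is available

HOMESERVER = "https://matrix.example.com"  # placeholder
ADMIN_TOKEN = "<admin access token>"  # placeholder


def protect_media(media_id: str) -> None:
    """Mark a piece of local media as protected from quarantine."""
    resp = requests.post(
        f"{HOMESERVER}/_synapse/admin/v1/media/protect/{media_id}",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        json={},
    )
    resp.raise_for_status()  # a successful call returns 200 with an empty JSON body


protect_media("abcdefg12345")
```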
# Delete local media
This API deletes the *local* media from the disk of your own server.
This includes any local thumbnails and copies of media downloaded from

docs/openid.md

@ -42,11 +42,10 @@ as follows:
* For other installation mechanisms, see the documentation provided by the
maintainer.
To enable the OpenID integration, you should then add an `oidc_config` section
to your configuration file (or uncomment the `enabled: true` line in the
existing section). See [sample_config.yaml](./sample_config.yaml) for some
sample settings, as well as the text below for example configurations for
specific providers.
To enable the OpenID integration, you should then add a section to the `oidc_providers`
setting in your configuration file (or uncomment one of the existing examples).
See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
the text below for example configurations for specific providers.
## Sample configs
@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links.
Edit your Synapse config file and change the `oidc_config` section:
```yaml
oidc_config:
enabled: true
issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
client_id: "<client id>"
client_secret: "<client secret>"
scopes: ["openid", "profile"]
authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
oidc_providers:
- idp_id: microsoft
idp_name: Microsoft
issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
client_id: "<client id>"
client_secret: "<client secret>"
scopes: ["openid", "profile"]
authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username.split('@')[0] }}"
display_name_template: "{{ user.name }}"
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username.split('@')[0] }}"
display_name_template: "{{ user.name }}"
```
### [Dex][dex-idp]
@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`.
Synapse config:
```yaml
oidc_config:
enabled: true
skip_verification: true # This is needed as Dex is served on an insecure endpoint
issuer: "http://127.0.0.1:5556/dex"
client_id: "synapse"
client_secret: "secret"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.name }}"
display_name_template: "{{ user.name|capitalize }}"
oidc_providers:
- idp_id: dex
idp_name: "My Dex server"
skip_verification: true # This is needed as Dex is served on an insecure endpoint
issuer: "http://127.0.0.1:5556/dex"
client_id: "synapse"
client_secret: "secret"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.name }}"
display_name_template: "{{ user.name|capitalize }}"
```
### [Keycloak][keycloak-idp]
@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
8. Copy Secret
```yaml
oidc_config:
enabled: true
issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
client_id: "synapse"
client_secret: "copy secret generated from above"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
oidc_providers:
- idp_id: keycloak
idp_name: "My KeyCloak server"
issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
client_id: "synapse"
client_secret: "copy secret generated from above"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
### [Auth0][auth0]
@ -191,16 +193,17 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
oidc_providers:
- idp_id: auth0
idp_name: Auth0
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
### GitHub
@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set.
Synapse config:
```yaml
oidc_config:
enabled: true
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
authorization_endpoint: "https://github.com/login/oauth/authorize"
token_endpoint: "https://github.com/login/oauth/access_token"
userinfo_endpoint: "https://api.github.com/user"
scopes: ["read:user"]
user_mapping_provider:
config:
subject_claim: "id"
localpart_template: "{{ user.login }}"
display_name_template: "{{ user.name }}"
oidc_providers:
- idp_id: github
idp_name: Github
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
authorization_endpoint: "https://github.com/login/oauth/authorize"
token_endpoint: "https://github.com/login/oauth/access_token"
userinfo_endpoint: "https://api.github.com/user"
scopes: ["read:user"]
user_mapping_provider:
config:
subject_claim: "id"
localpart_template: "{{ user.login }}"
display_name_template: "{{ user.name }}"
```
### [Google][google-idp]
@ -243,16 +247,17 @@ oidc_config:
2. add an "OAuth Client ID" for a Web Application under "Credentials".
3. Copy the Client ID and Client Secret, and add the following to your synapse config:
```yaml
oidc_config:
enabled: true
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.given_name|lower }}"
display_name_template: "{{ user.name }}"
oidc_providers:
- idp_id: google
idp_name: Google
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.given_name|lower }}"
display_name_template: "{{ user.name }}"
```
4. Back in the Google console, add this Authorized redirect URI: `[synapse
public baseurl]/_synapse/oidc/callback`.
@ -266,16 +271,17 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
issuer: "https://id.twitch.tv/oauth2/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: "client_secret_post"
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
oidc_providers:
- idp_id: twitch
idp_name: Twitch
issuer: "https://id.twitch.tv/oauth2/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: "client_secret_post"
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
### GitLab
@ -287,16 +293,17 @@ oidc_config:
Synapse config:
```yaml
oidc_config:
enabled: true
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: "client_secret_post"
scopes: ["openid", "read_user"]
user_profile_method: "userinfo_endpoint"
user_mapping_provider:
config:
localpart_template: '{{ user.nickname }}'
display_name_template: '{{ user.name }}'
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
client_auth_method: "client_secret_post"
scopes: ["openid", "read_user"]
user_profile_method: "userinfo_endpoint"
user_mapping_provider:
config:
localpart_template: '{{ user.nickname }}'
display_name_template: '{{ user.name }}'
```

docs/sample_config.yaml

@ -1709,141 +1709,149 @@ saml2_config:
#idp_entityid: 'https://our_idp/entityid'
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
#
# Options for each entry include:
#
# idp_id: a unique identifier for this identity provider. Used internally
# by Synapse; should be a single word such as 'github'.
#
# Note that, if this is changed, users authenticating via that provider
# will no longer be recognised as the same user!
#
# idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms.
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true.
#
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
# is enabled) to discover the provider's endpoints.
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
# scopes: list of scopes to request. This should normally include the "openid"
# scope. Defaults to ["openid"].
#
# authorization_endpoint: the oauth2 authorization endpoint. Required if
# provider discovery is disabled.
#
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
# disabled.
#
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
# disabled and the 'openid' scope is not requested.
#
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
# the 'openid' scope is used.
#
# skip_verification: set to 'true' to skip metadata verification. Use this if
# you are connecting to a provider that is not OpenID Connect compliant.
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
#
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
# match a pre-existing account instead of failing. This could be used if
# switching from password logins to OIDC. Defaults to false.
#
# user_mapping_provider: Configuration for how attributes returned from an OIDC
# provider are mapped onto a matrix user. This setting has the following
# sub-properties:
#
# module: The class name of a custom mapping module. Default is
# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
# config: Configuration for the mapping provider module. This section will
# be passed as a Python dictionary to the user mapping provider
# module's `parse_config` method.
#
# For the default provider, the following settings are available:
#
# subject_claim: name of the claim containing a unique identifier for the
# user. Defaults to 'sub', which OpenID Connect compliant
# providers should provide.
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
#
# extra_attributes: a map of Jinja2 templates for extra attributes
# to send back to the client during login.
# Note that these are non-standard and clients will ignore them
# without modifications.
#
# When rendering, the Jinja2 templates are given a 'user' variable,
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for some example configurations.
# for information on how to configure these options.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
# server. Defaults to false.
# For backwards compatibility, it is also possible to configure a single OIDC
# provider via an 'oidc_config' setting. This is now deprecated and admins are
# advised to migrate to the 'oidc_providers' format.
#
oidc_providers:
# Generic example
#
#enabled: true
#- idp_id: my_idp
# idp_name: "My OpenID provider"
# discover: false
# issuer: "https://accounts.example.com/"
# client_id: "provided-by-your-issuer"
# client_secret: "provided-by-your-issuer"
# client_auth_method: client_secret_post
# scopes: ["openid", "profile"]
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# token_endpoint: "https://accounts.example.com/oauth2/token"
# userinfo_endpoint: "https://accounts.example.com/userinfo"
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip_verification: true
# Uncomment the following to disable use of the OIDC discovery mechanism to
# discover endpoints. Defaults to true.
# For use with Keycloak
#
#discover: false
#- idp_id: keycloak
# idp_name: Keycloak
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
# For use with Github
#
# Required if 'enabled' is true.
#
#issuer: "https://accounts.example.com/"
# oauth2 client id to use.
#
# Required if 'enabled' is true.
#
#client_id: "provided-by-your-issuer"
# oauth2 client secret to use.
#
# Required if 'enabled' is true.
#
#client_secret: "provided-by-your-issuer"
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
#client_auth_method: client_secret_post
# list of scopes to request. This should normally include the "openid" scope.
# Defaults to ["openid"].
#
#scopes: ["openid", "profile"]
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# Uncomment to skip metadata verification. Defaults to false.
#
# Use this if you are connecting to a provider that is not OpenID Connect
# compliant.
# Avoid this in production.
#
#skip_verification: true
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
#
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
#
#user_profile_method: "userinfo_endpoint"
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
#allow_existing_users: true
# An external module can be provided here as a custom solution to mapping
# attributes returned from an OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
#module: mapping_provider.OidcMappingProvider
# Custom configuration values for the module. This section will be passed as
# a Python dictionary to the user mapping provider module's `parse_config`
# method.
#
# The examples below are intended for the default provider: they should be
# changed if using a custom provider.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID.
#
# When rendering, this template is given the following variables:
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
# If this is not set, the user will be prompted to choose their
# own username.
#
#localpart_template: "{{ user.preferred_username }}"
# Jinja2 template for the display name to set on first login.
#
# If unset, no displayname will be set.
#
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{ user.birthdate }}"
#- idp_id: github
# idp_name: Github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
# client_secret: "your-client-secret" # TO BE FILLED
# authorization_endpoint: "https://github.com/login/oauth/authorize"
# token_endpoint: "https://github.com/login/oauth/access_token"
# userinfo_endpoint: "https://api.github.com/user"
# scopes: ["read:user"]
# user_mapping_provider:
# config:
# subject_claim: "id"
# localpart_template: "{ user.login }"
# display_name_template: "{ user.name }"
# Enable Central Authentication Service (CAS) for registration and login.

docs/workers.md

@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
for a higher-level overview.
## Main process/worker communication
The processes communicate with each other via a Synapse-specific protocol called

synapse/config/cas.py

@ -40,7 +40,7 @@ class CasConfig(Config):
self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
cas_config:

synapse/config/oidc_config.py

@ -15,7 +15,7 @@
# limitations under the License.
import string
from typing import Optional, Type
from typing import Iterable, Optional, Type
import attr
@ -33,16 +33,8 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
validate_config(MAIN_CONFIG_SCHEMA, config, ())
self.oidc_provider = None # type: Optional[OidcProviderConfig]
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
if not self.oidc_provider:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return
try:
@ -58,144 +50,153 @@ class OIDCConfig(Config):
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
return bool(self.oidc_provider)
return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
#
# Options for each entry include:
#
# idp_id: a unique identifier for this identity provider. Used internally
# by Synapse; should be a single word such as 'github'.
#
# Note that, if this is changed, users authenticating via that provider
# will no longer be recognised as the same user!
#
# idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms.
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true.
#
# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
# is enabled) to discover the provider's endpoints.
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
# scopes: list of scopes to request. This should normally include the "openid"
# scope. Defaults to ["openid"].
#
# authorization_endpoint: the oauth2 authorization endpoint. Required if
# provider discovery is disabled.
#
# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
# disabled.
#
# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
# disabled and the 'openid' scope is not requested.
#
# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
# the 'openid' scope is used.
#
# skip_verification: set to 'true' to skip metadata verification. Use this if
# you are connecting to a provider that is not OpenID Connect compliant.
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
#
# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
# match a pre-existing account instead of failing. This could be used if
# switching from password logins to OIDC. Defaults to false.
#
# user_mapping_provider: Configuration for how attributes returned from an OIDC
# provider are mapped onto a matrix user. This setting has the following
# sub-properties:
#
# module: The class name of a custom mapping module. Default is
# {mapping_provider!r}.
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
# config: Configuration for the mapping provider module. This section will
# be passed as a Python dictionary to the user mapping provider
# module's `parse_config` method.
#
# For the default provider, the following settings are available:
#
# subject_claim: name of the claim containing a unique identifier for the
# user. Defaults to 'sub', which OpenID Connect compliant
# providers should provide.
#
# localpart_template: Jinja2 template for the localpart of the MXID.
# If this is not set, the user will be prompted to choose their
# own username.
#
# display_name_template: Jinja2 template for the display name to set
# on first login. If unset, no displayname will be set.
#
# extra_attributes: a map of Jinja2 templates for extra attributes
# to send back to the client during login.
# Note that these are non-standard and clients will ignore them
# without modifications.
#
# When rendering, the Jinja2 templates are given a 'user' variable,
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for some example configurations.
# for information on how to configure these options.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
# server. Defaults to false.
# For backwards compatibility, it is also possible to configure a single OIDC
# provider via an 'oidc_config' setting. This is now deprecated and admins are
# advised to migrate to the 'oidc_providers' format.
#
oidc_providers:
# Generic example
#
#enabled: true
#- idp_id: my_idp
# idp_name: "My OpenID provider"
# discover: false
# issuer: "https://accounts.example.com/"
# client_id: "provided-by-your-issuer"
# client_secret: "provided-by-your-issuer"
# client_auth_method: client_secret_post
# scopes: ["openid", "profile"]
# authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# token_endpoint: "https://accounts.example.com/oauth2/token"
# userinfo_endpoint: "https://accounts.example.com/userinfo"
# jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# skip_verification: true
# Uncomment the following to disable use of the OIDC discovery mechanism to
# discover endpoints. Defaults to true.
# For use with Keycloak
#
#discover: false
#- idp_id: keycloak
# idp_name: Keycloak
# issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
# For use with Github
#
# Required if 'enabled' is true.
#
#issuer: "https://accounts.example.com/"
# oauth2 client id to use.
#
# Required if 'enabled' is true.
#
#client_id: "provided-by-your-issuer"
# oauth2 client secret to use.
#
# Required if 'enabled' is true.
#
#client_secret: "provided-by-your-issuer"
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic' (default), 'client_secret_post' and
# 'none'.
#
#client_auth_method: client_secret_post
# list of scopes to request. This should normally include the "openid" scope.
# Defaults to ["openid"].
#
#scopes: ["openid", "profile"]
# the oauth2 authorization endpoint. Required if provider discovery is disabled.
#
#authorization_endpoint: "https://accounts.example.com/oauth2/auth"
# the oauth2 token endpoint. Required if provider discovery is disabled.
#
#token_endpoint: "https://accounts.example.com/oauth2/token"
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
#
#userinfo_endpoint: "https://accounts.example.com/userinfo"
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
#
#jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
# Uncomment to skip metadata verification. Defaults to false.
#
# Use this if you are connecting to a provider that is not OpenID Connect
# compliant.
# Avoid this in production.
#
#skip_verification: true
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
#
# Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
# in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
#
#user_profile_method: "userinfo_endpoint"
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
#allow_existing_users: true
# An external module can be provided here as a custom solution to mapping
# attributes returned from an OIDC provider onto a matrix user.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
# Default is {mapping_provider!r}.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
#module: mapping_provider.OidcMappingProvider
# Custom configuration values for the module. This section will be passed as
# a Python dictionary to the user mapping provider module's `parse_config`
# method.
#
# The examples below are intended for the default provider: they should be
# changed if using a custom provider.
#
config:
# name of the claim containing a unique identifier for the user.
# Defaults to `sub`, which OpenID Connect compliant providers should provide.
#
#subject_claim: "sub"
# Jinja2 template for the localpart of the MXID.
#
# When rendering, this template is given the following variables:
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
# If this is not set, the user will be prompted to choose their
# own username.
#
#localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login.
#
# If unset, no displayname will be set.
#
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{{{ user.birthdate }}}}"
#- idp_id: github
# idp_name: Github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
# client_secret: "your-client-secret" # TO BE FILLED
# authorization_endpoint: "https://github.com/login/oauth/authorize"
# token_endpoint: "https://github.com/login/oauth/access_token"
# userinfo_endpoint: "https://api.github.com/user"
# scopes: ["read:user"]
# user_mapping_provider:
# config:
# subject_claim: "id"
# localpart_template: "{{ user.login }}"
# display_name_template: "{{ user.name }}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
},
}
# the `oidc_config` setting can either be None (as it is in the default
# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
"allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
}
# the `oidc_providers` list can either be None (as it is in the default config), or
# a list of provider configs, each of which requires an explicit ID and name.
OIDC_PROVIDER_LIST_SCHEMA = {
"oneOf": [
{"type": "null"},
{"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
]
}
# the `oidc_config` setting can either be None (which it used to be in the default
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
# property.
#
@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
# additional checks in the code.
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
MAIN_CONFIG_SCHEMA = {
"type": "object",
"properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
"properties": {
"oidc_config": OIDC_CONFIG_SCHEMA,
"oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
},
}
def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
"""extract and parse the OIDC provider configs from the config dict
The configuration may contain either a single `oidc_config` object with an
`enabled: True` property, or a list of provider configurations under
`oidc_providers`, *or both*.
Returns a generator which yields the OidcProviderConfig objects
"""
validate_config(MAIN_CONFIG_SCHEMA, config, ())
for p in config.get("oidc_providers") or []:
yield _parse_oidc_config_dict(p)
# for backwards-compatibility, it is also possible to provide a single "oidc_config"
# object with an "enabled: True" property.
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
# MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
# it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
# above), so now we need to validate it.
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
yield _parse_oidc_config_dict(oidc_config)
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig

synapse/config/registration.py

@ -14,14 +14,13 @@
# limitations under the License.
import os
from distutils.util import strtobool
import pkg_resources
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols
from synapse.util.stringutils import random_string_with_symbols, strtobool
class AccountValidityConfig(Config):
@ -86,12 +85,12 @@ class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
self.enable_registration = strtobool(
str(config.get("enable_registration", False))
)
if "disable_registration" in config:
self.enable_registration = not bool(
strtobool(str(config["disable_registration"]))
self.enable_registration = not strtobool(
str(config["disable_registration"])
)
self.account_validity = AccountValidityConfig(

synapse/events/__init__.py

@ -17,7 +17,6 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
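
For reference, the replacement helper in `synapse.util.stringutils` plausibly mirrors `distutils.util.strtobool` while returning a real `bool`. A sketch of those semantics (treat the exact implementation as an assumption):

```python
def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False.

    Accepts the same values as distutils.util.strtobool, but returns a bool
    rather than an int.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    elif val in ("n", "no", "f", "false", "off", "0"):
        return False
    else:
        raise ValueError("invalid truth value %r" % (val,))
```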

synapse/handlers/devicemessage.py

@ -163,7 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
run_in_background(self._user_device_resync, sender_user_id)
run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,

synapse/handlers/oidc_handler.py

@ -78,21 +78,28 @@ class OidcHandler:
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
provider_conf = hs.config.oidc.oidc_provider
provider_confs = hs.config.oidc.oidc_providers
# we should not have been instantiated if there is no configured provider.
assert provider_conf is not None
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
}
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
await self._provider.load_metadata()
await self._provider.load_jwks()
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
) from e
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
@ -184,6 +191,12 @@ class OidcHandler:
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
oidc_provider = self._providers.get(session_data.idp_id)
if not oidc_provider:
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
return
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
@ -193,7 +206,7 @@ class OidcHandler:
code = request.args[b"code"][0].decode()
await self._provider.handle_oidc_callback(request, session_data, code)
await oidc_provider.handle_oidc_callback(request, session_data, code)
class OidcError(Exception):

synapse/http/client.py

@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
# If the deferred was called, bail early.
if self.deferred.called:
return
self.stream.write(data)
self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
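
To make the new guards concrete, a small demonstration of the post-fix behaviour (a sketch: the stub transport and the `(stream, deferred, max_size)` constructor order are assumptions read off the hunk above):

```python
from io import BytesIO

from twisted.internet import defer

from synapse.http.client import _ReadBodyWithMaxSizeProtocol


class StubTransport:
    """Bare-minimum transport: the protocol only calls loseConnection()."""

    def loseConnection(self) -> None:
        pass


d = defer.Deferred()
d.addErrback(lambda f: print("errback fired once:", f.type.__name__))

proto = _ReadBodyWithMaxSizeProtocol(BytesIO(), d, max_size=4)
proto.transport = StubTransport()

proto.dataReceived(b"12345")  # exceeds max_size: errbacks and drops the connection
proto.dataReceived(b"late")  # further data bails early instead of erroring again
```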

synapse/rest/admin/media.py

@ -15,6 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, room_id: str):
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id: str):
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, server_name: str, media_id: str):
async def on_POST(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined.
"""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logger.info("Protecting local media by ID: %s", media_id)
# Mark this media id as safe from quarantine
await self.store.mark_local_media_as_safe(media_id)
return 200, {}
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_DELETE(self, request, server_name: str, media_id: str):
async def on_DELETE(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_POST(self, request, server_name: str):
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
def register_servlets_for_media_repo(hs, http_server):
def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)

synapse/rest/media/v1/_base.py

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,10 +17,11 @@
import logging
import os
import urllib
from typing import Awaitable
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]
def parse_media_id(request):
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@ -69,7 +70,7 @@ def parse_media_id(request):
)
def respond_404(request):
def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None
):
request: Request,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name):
def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
request
media_type: The media/content type.
file_size: Size in bytes of the media, if known.
upload_name: The name of the requested file, if any.
"""
def _quote(x):
@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
def _can_encode_filename_as_token(x):
def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None
):
request: Request,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
request
responder
media_type: The media/content type.
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@ -308,22 +323,22 @@ class FileInfo:
self.thumbnail_type = thumbnail_type
def get_filename_from_headers(headers):
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
headers (dict[bytes, list[bytes]]): The HTTP request headers.
headers: The HTTP request headers.
Returns:
A Unicode string of the filename, or None.
The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
return
return None
_, params = _parse_header(content_disposition[0])
@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
return upload_name
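
A worked example of what this helper returns (a sketch, using byte-keyed headers as Twisted supplies them):

```python
from synapse.rest.media.v1._base import get_filename_from_headers

headers = {b"Content-Disposition": [b'inline; filename="cat.png"']}
print(get_filename_from_headers(headers))  # -> cat.png

print(get_filename_from_headers({}))  # no header -> None
```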
def _parse_header(line):
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
line (bytes): header to be parsed
line: header to be parsed
Returns:
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@ -386,16 +400,16 @@ def _parse_header(line):
return key, pdict
def _parseparam(s):
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
s (bytes): header to be parsed
s: header to be parsed
Returns:
Iterable[bytes]: the split input
The split input
"""
while s[:1] == b";":
s = s[1:]

synapse/rest/media/v1/config_resource.py

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,22 +15,29 @@
# limitations under the License.
#
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)

synapse/rest/media/v1/download_resource.py

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True
)
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",

synapse/rest/media/v1/media_filepaths.py

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,11 +17,12 @@
import functools
import os
import re
from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
def _wrap_in_base_path(func):
def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
def __init__(self, primary_base_path):
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
self, default_top_level, default_sub_type, width, height, content_type, method
):
self,
default_top_level: str,
default_sub_type: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
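
As a concrete illustration of the `_rel`/`_wrap_in_base_path` pairing (a sketch; the base path is a placeholder):

```python
from synapse.rest.media.v1.media_filepaths import MediaFilePaths

paths = MediaFilePaths("/data/media_store")  # placeholder base path

# Media ids are sharded into two 2-character directory levels:
print(paths.local_media_filepath_rel("abcdefg12345"))
# -> local_content/ab/cd/efg12345

print(paths.local_media_filepath("abcdefg12345"))
# -> /data/media_store/local_content/ab/cd/efg12345
```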
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
def remote_media_filepath_rel(self, server_name, file_id):
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
self, server_name, file_id, width, height, content_type, method
):
self,
server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name, file_id, width, height, content_type
self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
def remote_media_thumbnail_dir(self, server_name, file_id):
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
def url_cache_filepath_rel(self, media_id):
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
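The return statements of this function are elided in the hunk; a sketch of the sharding scheme implied by the comment (the exact slice boundaries are an assumption based on the `<DATE><RANDOM_STRING>` format):

```python
import os
import re

NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")

def url_cache_filepath_rel(media_id: str) -> str:
    if NEW_FORMAT_ID_RE.match(media_id):
        # Date-prefixed ids are grouped by day, e.g.
        # url_cache/2017-09-28/fsdRDt24DS234dsf
        return os.path.join("url_cache", media_id[:10], media_id[11:])
    # Legacy ids shard by leading character pairs instead.
    return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
```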
@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id):
def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import os
import shutil
from typing import IO, Dict, List, Optional, Tuple
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
async def _update_recently_accessed(self):
async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
def mark_recently_accessed(self, server_name, media_id):
def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
server_name (str|None): Origin server of media, or None if local
media_id (str): The media ID of the content
server_name: Origin server of media, or None if local
media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
def _generate_thumbnail(
self,
thumbnailer: Thumbnailer,
t_width: int,
t_height: int,
t_method: str,
t_type: str,
) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height
@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
t_byte_source = None
return thumbnailer.scale(t_width, t_height, t_type)
return t_byte_source
return None
async def generate_local_exact_thumbnail(
self,
@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
async def delete_old_remote_media(self, before_ts):
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vecotr Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest):
def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file (file): A file like object to be streamed ot the client,
open_file: A file-like object to be streamed to the client,
is closed when finished streaming.
"""
def __init__(self, open_file):
def __init__(self, open_file: IO):
self.open_file = open_file
def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import errno
import fnmatch
@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
from typing import Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
from lxml import etree
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
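The regex sniffs a charset out of the raw bytes before any decoding happens; a quick illustrative probe (sample HTML invented):

```python
import re

_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)

body = b"<html><head><meta charset=utf-8></head><body>hi</body></html>"
match = _charset_match.search(body)
if match:
    print(match.group(1).decode("ascii"))  # utf-8
```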
@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.auth = hs.get_auth()
@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
async def _download_url(self, url: str, user):
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
async def _expire_url_cache_data(self):
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og
def _calc_og(tree, media_uri):
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
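For orientation, the rough shape of the dict `_calc_og` builds up (the keys are standard OpenGraph properties; the values here are invented):

```python
og = {
    "og:title": "Some page title",
    "og:type": "website",
    "og:image": "https://example.com/preview.png",
    "og:description": "The first couple of paragraphs, summarized...",
}
```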
@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
else:
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
return og
def _iterate_over_text(tree, *tags_to_ignore):
def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith("/"):
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
def _rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)
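Two illustrative calls against `_rebase_url` as defined above (example URLs invented):

```python
print(_rebase_url("images/logo.png", "https://example.com/blog/post"))
# https://example.com/blog/images/logo.png  (schema, host and directory inherited)

print(_rebase_url("//cdn.example.com/a.png", "https://example.com/blog/post"))
# https://cdn.example.com/a.png  (only the schema needed fixing up)
```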
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
def _is_html(content_type):
def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
if content_type.startswith("text/html") or content_type.startswith(
return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
):
return True
)
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
# Try to get a summary of between 200 and 500 characters, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
import shutil
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
@ -27,13 +28,17 @@ from .media_storage import FileResponder
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class StorageProvider:
class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
async def store_file(self, path: str, file_info: FileInfo):
@abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
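Now that `StorageProvider` is an ABC with both methods abstract, a subclass that omits either cannot be instantiated. A minimal hypothetical provider, reusing the names imported in this module:

```python
class NullStorageProvider(StorageProvider):
    """A do-nothing provider, for illustration only."""

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        # Accept the upload but store nothing.
        return None

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        # Never has anything to serve.
        return None
```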
@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def __str__(self):
def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.store_file(path, file_info))
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
return None
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
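`maybe_awaitable` is what lets the wrapper tolerate provider implementations whose `store_file`/`fetch` are not actually async. Synapse's real helper lives in `synapse.util`; the body below is an assumed sketch of the idea:

```python
import inspect
from typing import Any

async def maybe_awaitable(value: Any) -> Any:
    # Await the value if the provider returned an awaitable,
    # otherwise pass the plain result straight through.
    if inspect.isawaitable(value):
        return await value
    return value
```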
@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs (HomeServer)
hs
config: The config returned by `parse_config`.
"""
def __init__(self, hs, config):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
return await defer_to_thread(
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
return None
@staticmethod
def parse_config(config):
def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +16,14 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.store = hs.get_datastore()
@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
self,
request: Request,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
self,
request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
def _select_thumbnail(
self,
desired_width,
desired_height,
desired_method,
desired_type,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos,
):
) -> dict:
d_w = desired_width
d_h = desired_height

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image
@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
def transpose(self):
def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels.
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size
def aspect(self, max_width, max_height):
def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
@ -91,7 +93,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
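Only the height-limited branch is visible in this hunk; a standalone sketch of the whole calculation, with the width-limited branch reconstructed (so treat it as illustrative), plus two worked cases:

```python
from typing import Tuple

def aspect(width: int, height: int, max_width: int, max_height: int) -> Tuple[int, int]:
    if max_width * height < max_height * width:
        # Width-limited: the width fills max_width.
        return max_width, (max_width * height) // width
    # Height-limited branch, as in the diff above.
    return (max_height * width) // height, max_height

print(aspect(1600, 900, 320, 320))  # (320, 180)
print(aspect(900, 1600, 320, 320))  # (180, 320)
```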
def _resize(self, width, height):
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type):
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
def _encode_image(self, output_image, output_type):
def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,18 +15,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request):
async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point

View file

@ -16,6 +16,8 @@
import logging
from typing import Dict, List, Optional, Tuple
import attr
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn.
"""
# The last room_id/depth/stream processed.
room_id = attr.ib(type=str)
depth = attr.ib(type=int)
stream = attr.ib(type=int)
# Number of rows processed
processed_count = attr.ib(type=int)
# Map from room_id to last depth/stream processed for each room that we have
# processed all events for (i.e. the rooms we can flip the
# `has_auth_chain_index` for)
finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "")
# Have we finished processing the current room.
finished = progress.get("finished", True)
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
# Have we set the `has_auth_chain_index` for the room yet.
has_set_room_has_chain_index = progress.get(
"has_set_room_has_chain_index", False
result = await self.db_pool.runInteraction(
"_chain_cover_index",
self._calculate_chain_cover_txn,
current_room_id,
last_depth,
last_stream,
batch_size,
single_room=False,
)
if finished:
# If we've finished with the previous room (or its our first
# iteration) we move on to the next room.
finished = result.processed_count == 0
def _get_next_room(txn: Cursor) -> Optional[str]:
sql = """
SELECT room_id FROM rooms
WHERE room_id > ?
AND (
NOT has_auth_chain_index
OR has_auth_chain_index IS NULL
)
ORDER BY room_id
LIMIT 1
"""
txn.execute(sql, (current_room_id,))
row = txn.fetchone()
if row:
return row[0]
total_rows_processed = result.processed_count
current_room_id = result.room_id
last_depth = result.depth
last_stream = result.stream
return None
current_room_id = await self.db_pool.runInteraction(
"_chain_cover_index", _get_next_room
)
if not current_room_id:
await self.db_pool.updates._end_background_update("chain_cover")
return 0
logger.debug("Adding chain cover to %s", current_room_id)
def _calculate_auth_chain(
txn: Cursor, last_depth: int, last_stream: int
) -> Tuple[int, int, int]:
# Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order.
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
# comparison, but that is not supported on older SQLite versions
tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine,
[
("topological_ordering", last_depth),
("stream_ordering", last_stream),
],
)
sql = """
SELECT
event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE events.room_id = ?
AND event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s
ORDER BY topological_ordering, stream_ordering
LIMIT ?
""" % {
"tuple_cmp": tuple_clause,
}
args = [current_room_id]
args.extend(tuple_args)
args.append(batch_size)
txn.execute(sql, args)
rows = txn.fetchall()
# Put the results in the necessary format for
# `_add_chain_cover_index`
event_to_room_id = {row[0]: current_room_id for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int
count = len(rows)
# We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
)
event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(
row["auth_id"]
)
# Calculate and persist the chain cover index for this set of events.
#
# Annoyingly we need to gut wrench into the persist event store so that
# we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
return new_last_depth, new_last_stream, count
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed = count
if count < batch_size and not has_set_room_has_chain_index:
for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that it's possible for
# further events to be persisted between the above and setting the
@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": current_room_id},
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
result = await self.db_pool.runInteraction(
"_chain_cover_index",
self._calculate_chain_cover_txn,
room_id,
depth,
stream,
batch_size=None,
single_room=True,
)
total_rows_processed += count
total_rows_processed += result.processed_count
# Note that at this point it's technically possible that more events
# than our `batch_size` have been persisted without their chain
# cover, so we need to continue processing this room if the last
# count returned was equal to the `batch_size`.
if finished:
await self.db_pool.updates._end_background_update("chain_cover")
return total_rows_processed
if count < batch_size:
# We've finished calculating the index for this room, move on to the
# next room.
await self.db_pool.updates._background_update_progress(
"chain_cover", {"current_room_id": current_room_id, "finished": True},
)
else:
# We still have outstanding events to calculate the index for.
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
"has_auth_chain_index": has_set_room_has_chain_index,
"finished": False,
},
)
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
},
)
return total_rows_processed
def _calculate_chain_cover_txn(
self,
txn: Cursor,
last_room_id: str,
last_depth: int,
last_stream: int,
batch_size: Optional[int],
single_room: bool,
) -> _CalculateChainCover:
"""Calculate the chain cover for `batch_size` events, ordered by
`(room_id, depth, stream)`.
Args:
txn,
last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
tuple to fetch results after.
batch_size: The maximum number of events to process. If None then
no limit.
single_room: Whether to calculate the index for just the given
room.
"""
# Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order.
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
# comparison, but that is not supported on older SQLite versions
tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine,
[
("events.room_id", last_room_id),
("topological_ordering", last_depth),
("stream_ordering", last_stream),
],
)
extra_clause = ""
if single_room:
extra_clause = "AND events.room_id = ?"
tuple_args.append(last_room_id)
sql = """
SELECT
event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering,
events.room_id
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s
%(extra)s
ORDER BY events.room_id, topological_ordering, stream_ordering
%(limit)s
""" % {
"tuple_cmp": tuple_clause,
"limit": "LIMIT ?" if batch_size is not None else "",
"extra": extra_clause,
}
if batch_size is not None:
tuple_args.append(batch_size)
txn.execute(sql, tuple_args)
rows = txn.fetchall()
# Put the results in the necessary format for
# `_add_chain_cover_index`
event_to_room_id = {row[0]: row[5] for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
# Calculate the new last position we've processed up to.
new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int
new_last_room_id = rows[-1][5] if rows else "" # type: str
# Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also
# need to include the room passed in if it's not included in the result
# set (as we then know we've processed all events in said room).
#
# This is the set of rooms that we can now safely flip the
# `has_auth_chain_index` bit for.
finished_rooms = {
row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
}
if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
finished_rooms[last_room_id] = (last_depth, last_stream)
count = len(rows)
# We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
)
event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
# Calculate and persist the chain cover index for this set of events.
#
# Annoyingly we need to gut wrench into the persist event store so that
# we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index(
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
return _CalculateChainCover(
room_id=new_last_room_id,
depth=new_last_depth,
stream=new_last_stream,
processed_count=count,
finished_room_map=finished_rooms,
)
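The `(room_id, depth, stream)` resume point makes this a keyset-paginated scan, and `make_tuple_comparison_clause` exists because older SQLite cannot compare row values directly. A hand-written sketch of one standard expansion of such a tuple comparison (Synapse's helper may emit a different but equivalent form):

```python
from typing import List, Tuple

def tuple_gt_clause(pairs: List[Tuple[str, object]]) -> Tuple[str, List[object]]:
    """Build SQL equivalent to (col1, col2, ...) > (?, ?, ...) without
    row-value syntax: a disjunction of "earlier columns equal, this one
    strictly greater" terms."""
    clauses = []
    args = []  # type: List[object]
    for i in range(len(pairs)):
        terms = ["%s = ?" % (col,) for col, _ in pairs[:i]]
        terms.append("%s > ?" % (pairs[i][0],))
        clauses.append("(" + " AND ".join(terms) + ")")
        args.extend(val for _, val in pairs[: i + 1])
    return "(" + " OR ".join(clauses) + ")", args

clause, args = tuple_gt_clause(
    [("events.room_id", "!r:ex"), ("topological_ordering", 3), ("stream_ordering", 10)]
)
# clause: ((events.room_id > ?) OR (events.room_id = ? AND topological_ordering > ?)
#          OR (events.room_id = ? AND topological_ordering = ? AND stream_ordering > ?))
# args:   ['!r:ex', '!r:ex', 3, '!r:ex', 3, 10]
```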

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_local_media_before(
self, before_ts: int, size_gt: int, keep_profiles: bool,
) -> Optional[List[str]]:
) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`

View file

@ -17,14 +17,13 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": bytearray(encode_canonical_json(data)),
"data": json_encoder.encode(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,

View file

@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems:
return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
def strtobool(val: str) -> bool:
"""Convert a string representation of truth to True or False
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
This is lifted from distutils.util.strtobool, with the exception that it actually
returns a bool, rather than an int.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError("invalid truth value %r" % (val,))
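Quick illustrative calls (inputs invented):

```python
print(strtobool("YES"))  # True -- the value is lowercased internally
print(strtobool("off"))  # False
try:
    strtobool("maybe")
except ValueError as exc:
    print(exc)           # invalid truth value 'maybe'
```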

View file

@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
self.provider = self.handler._provider
self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@ -866,7 +866,7 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler()
provider = handler._provider
provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={})
provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)

View file

@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# Expire both caches and repeat the request
self.reactor.pump((10000.0,))
# Repated the request, this time it should fail if the lookup fails.
# Repeat the request, this time it should fail if the lookup fails.
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
)
# The result is sucessful, but disabled delegation.
# The result is successful, but disabled delegation.
r = self.successResultOf(fetch_d)
self.assertIsNone(r.delegated_server)

101
tests/http/test_client.py Normal file
View file

@ -0,0 +1,101 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from mock import Mock
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
def setUp(self):
"""Start reading the body, returns the response, result and proto"""
self.response = Mock()
self.result = BytesIO()
self.deferred = read_body_with_max_size(self.response, self.result, 6)
# Fish the protocol out of the response.
self.protocol = self.response.deliverBody.call_args[0][0]
self.protocol.transport = Mock()
def _cleanup_error(self):
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f):
called[0] = True
self.deferred.addErrback(errback)
self.assertTrue(called[0])
def test_no_error(self):
"""A response that is NOT too large."""
# Start sending data.
self.protocol.dataReceived(b"12345")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"12345")
self.assertEqual(self.deferred.result, 5)
def test_too_large(self):
"""A response which is too large raises an exception."""
# Start sending data.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()
def test_multiple_packets(self):
"""Data should be accummulated through mutliple packets."""
# Start sending data.
self.protocol.dataReceived(b"12")
self.protocol.dataReceived(b"34")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234")
self.assertEqual(self.deferred.result, 4)
def test_additional_data(self):
"""A connection can receive data after being closed."""
# Start sending data.
self.protocol.dataReceived(b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self.protocol.transport.loseConnection.assert_called_once()
# More data might have come in.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()

View file

@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
self.get_success(self.store.mark_local_media_as_safe(media_id_2))
# Protect the media from being quarantined
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(

View file

@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
session_id = channel.json_body["session"]
# do the OIDC auth, but auth as the wrong user
channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id)
channel = self.helper.auth_via_oidc(
{"sub": "wrong_user"}, ui_auth_session_id=session_id
)
# that should return a failure message
self.assertSubstring("We were unable to validate", channel.text_body)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple
from twisted.trial import unittest
@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]
def test_background_update(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
self.requester = create_requester(self.user_id)
def _generate_room(self) -> Tuple[str, List[Set[str]]]:
"""Insert a room without a chain cover index.
"""
# Create a room
user_id = self.register_user("foo", "pass")
token = self.login("foo", "pass")
room_id = self.helper.create_room_as(user_id, tok=token)
requester = create_requester(user_id)
store = self.hs.get_datastore()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Mark the room as not having a chain cover index
self.get_success(
store.db_pool.simple_update(
self.store.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False},
@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Create a fork in the DAG with different events.
event_handler = self.hs.get_event_creation_handler()
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
event, context = self.get_success(
event_handler.create_event(
requester,
self.requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
"sender": self.user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
event_handler.handle_new_client_event(self.requester, event, context)
)
state1 = list(self.get_success(context.get_current_state_ids()).values())
state1 = set(self.get_success(context.get_current_state_ids()).values())
event, context = self.get_success(
event_handler.create_event(
requester,
self.requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
"sender": self.user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
event_handler.handle_new_client_event(self.requester, event, context)
)
state2 = list(self.get_success(context.get_current_state_ids()).values())
state2 = set(self.get_success(context.get_current_state_ids()).values())
# Delete the chain cover info.
@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
self.get_success(store.db_pool.runInteraction("test", _delete_tables))
self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
return room_id, [state1, state2]
def test_background_update_single_room(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create a room
room_id, states = self._generate_room()
# Insert and run the background update.
self.get_success(
store.db_pool.simple_insert(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
store.db_pool.updates._all_done = False
self.store.db_pool.updates._all_done = False
while not self.get_success(
store.db_pool.updates.has_completed_background_updates()
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
store.db_pool.updates.do_next_background_update(100), by=0.1
self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
store.db_pool.runInteraction(
self.store.db_pool.runInteraction(
"test",
store._get_auth_chain_difference_using_cover_index_txn,
self.store._get_auth_chain_difference_using_cover_index_txn,
room_id,
[state1, state2],
states,
)
)
def test_background_update_multiple_rooms(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create the rooms
room_id1, states1 = self._generate_room()
room_id2, states2 = self._generate_room()
room_id3, _ = self._generate_room()
# Insert and run the background update.
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
self.store.db_pool.runInteraction(
"test",
self.store._get_auth_chain_difference_using_cover_index_txn,
room_id1,
states1,
)
)
def test_background_update_single_large_room(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create a room
room_id, states = self._generate_room()
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
for i in range(0, 150):
self.helper.send_state(
room_id, event_type="m.test", body={"index": i}, tok=self.token
)
# Insert and run the background update.
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
iterations = 0
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
iterations += 1
self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
# room.
self.assertGreater(iterations, 1)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
self.store.db_pool.runInteraction(
"test",
self.store._get_auth_chain_difference_using_cover_index_txn,
room_id,
states,
)
)
def test_background_update_multiple_large_room(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create the rooms
room_id1, _ = self._generate_room()
room_id2, _ = self._generate_room()
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
for i in range(0, 150):
self.helper.send_state(
room_id1, event_type="m.test", body={"index": i}, tok=self.token
)
for i in range(0, 150):
self.helper.send_state(
room_id2, event_type="m.test", body={"index": i}, tok=self.token
)
# Insert and run the background update.
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
iterations = 0
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
iterations += 1
self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
# room.
self.assertGreater(iterations, 1)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))