Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

Erik Johnston 2021-01-18 11:14:37 +00:00
commit f5ab7d8306
43 changed files with 1337 additions and 753 deletions

changelog.d/9086.feature (new file, +1)

@@ -0,0 +1 @@
+Add an admin API for protecting local media from quarantine.

changelog.d/9093.misc (new file, +1)

@@ -0,0 +1 @@
+Add type hints to media repository.

changelog.d/9108.bugfix (new file, +1)

@@ -0,0 +1 @@
+Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files are too large.

changelog.d/9110.feature (new file, +1)

@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.

changelog.d/9117.bugfix (new file, +1)

@@ -0,0 +1 @@
+Fix corruption of `pushers` data when a postgres bouncer is used.

changelog.d/9124.misc (new file, +1)

@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.

changelog.d/9125.misc (new file, +1)

@@ -0,0 +1 @@
+Remove dependency on `distutils`.

changelog.d/9130.feature (new file, +1)

@@ -0,0 +1 @@
+Add experimental support for handling and persisting to-device messages on worker processes.

debian/changelog (vendored, +6)

@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
+
+  * Remove dependency on `python3-distutils`.
+
+ -- Richard van der Hoff <richard@matrix.org>  Fri, 15 Jan 2021 12:44:19 +0000
+
 matrix-synapse-py3 (1.25.0) stable; urgency=medium
 
   [ Dan Callahan ]

debian/control (vendored, -1)

@@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
 Depends:
  adduser,
  debconf,
- python3-distutils|libpython3-stdlib (<< 3.6),
  ${misc:Depends},
  ${shlibs:Depends},
  ${synapse:pydepends},

@@ -4,6 +4,7 @@
 * [Quarantining media by ID](#quarantining-media-by-id)
 * [Quarantining media in a room](#quarantining-media-in-a-room)
 * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
 - [Delete local media](#delete-local-media)
 * [Delete a specific local media](#delete-a-specific-local-media)
 * [Delete local media by date or size](#delete-local-media-by-date-or-size)

@@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:
 * `num_quarantined`: integer - The number of media items successfully quarantined
 
+## Protecting media from being quarantined
+
+This API protects a single piece of local media from being quarantined using the
+above APIs. This is useful for sticker packs and other shared media which you do
+not want to get quarantined, especially when
+[quarantining media in a room](#quarantining-media-in-a-room).
+
+Request:
+
+```
+POST /_synapse/admin/v1/media/protect/<media_id>
+
+{}
+```
+
+Where `media_id` is in the form of `abcdefg12345...`.
+
+Response:
+
+```json
+{}
+```
+
 # Delete local media
 This API deletes the *local* media from the disk of your own server.
 This includes any local thumbnails and copies of media downloaded from
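The protect endpoint above is a plain admin POST with an empty JSON body. A minimal sketch of building such a request (the homeserver URL, media ID and admin token are placeholder values, not from the commit):

```python
from urllib.request import Request

def protect_media_request(base_url: str, media_id: str, access_token: str) -> Request:
    """Build (but do not send) the admin request that protects one local
    media item from quarantine, per the API documented above."""
    return Request(
        f"{base_url}/_synapse/admin/v1/media/protect/{media_id}",
        data=b"{}",  # the endpoint takes an empty JSON object as its body
        headers={"Authorization": f"Bearer {access_token}"},
        method="POST",
    )

req = protect_media_request(
    "https://homeserver.example", "abcdefg12345", "<admin access token>"
)
```

Sending it (e.g. via `urllib.request.urlopen(req)`) should return the empty JSON object `{}` on success.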

@@ -42,11 +42,10 @@ as follows:
 * For other installation mechanisms, see the documentation provided by the
   maintainer.
 
-To enable the OpenID integration, you should then add an `oidc_config` section
-to your configuration file (or uncomment the `enabled: true` line in the
-existing section). See [sample_config.yaml](./sample_config.yaml) for some
-sample settings, as well as the text below for example configurations for
-specific providers.
+To enable the OpenID integration, you should then add a section to the `oidc_providers`
+setting in your configuration file (or uncomment one of the existing examples).
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+the text below for example configurations for specific providers.
 
 ## Sample configs
@@ -62,8 +61,9 @@ Directory (tenant) ID as it will be used in the Azure links.
 Edit your Synapse config file and change the `oidc_config` section:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: microsoft
+    idp_name: Microsoft
     issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
     client_id: "<client id>"
     client_secret: "<client secret>"
@@ -103,8 +103,9 @@ Run with `dex serve examples/config-dev.yaml`.
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: dex
+    idp_name: "My Dex server"
     skip_verification: true # This is needed as Dex is served on an insecure endpoint
     issuer: "http://127.0.0.1:5556/dex"
     client_id: "synapse"
@@ -152,8 +153,9 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 8. Copy Secret
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: keycloak
+    idp_name: "My KeyCloak server"
     issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
     client_id: "synapse"
     client_secret: "copy secret generated from above"
@@ -191,8 +193,9 @@ oidc_config:
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: auth0
+    idp_name: Auth0
     issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
@@ -219,8 +222,9 @@ does not return a `sub` property, an alternative `subject_claim` has to be set.
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: github
+    idp_name: Github
     discover: false
     issuer: "https://github.com/"
     client_id: "your-client-id" # TO BE FILLED
@@ -243,8 +247,9 @@ oidc_config:
 2. add an "OAuth Client ID" for a Web Application under "Credentials".
 3. Copy the Client ID and Client Secret, and add the following to your synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: google
+    idp_name: Google
     issuer: "https://accounts.google.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
@@ -266,8 +271,9 @@ oidc_config:
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: twitch
+    idp_name: Twitch
     issuer: "https://id.twitch.tv/oauth2/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
@@ -287,8 +293,9 @@ oidc_config:
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
+oidc_providers:
+  - idp_id: gitlab
+    idp_name: Gitlab
     issuer: "https://gitlab.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
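Each of the provider examples above follows the same mechanical migration: drop `enabled: true`, nest the remaining keys under a list entry, and add the new compulsory `idp_id`/`idp_name` fields. A hedged sketch of that transformation (`migrate_oidc_config` is an illustrative helper, not part of Synapse):

```python
def migrate_oidc_config(legacy: dict, idp_id: str, idp_name: str) -> list:
    """Convert a legacy single-provider `oidc_config` block into an
    `oidc_providers` list entry of the shape shown above."""
    # everything except the old `enabled` flag carries over unchanged
    entry = {k: v for k, v in legacy.items() if k != "enabled"}
    entry["idp_id"] = idp_id
    entry["idp_name"] = idp_name
    return [entry]

providers = migrate_oidc_config(
    {"enabled": True, "issuer": "https://gitlab.com/", "client_id": "your-client-id"},
    idp_id="gitlab",
    idp_name="Gitlab",
)
```

The resulting list serialises to YAML in exactly the `oidc_providers` layout used by the configs above.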

@@ -1709,141 +1709,149 @@ saml2_config:
   #idp_entityid: 'https://our_idp/entityid'
 
-# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
-#
-# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-# for some example configurations.
-#
-oidc_config:
-  # Uncomment the following to enable authorization against an OpenID Connect
-  # server. Defaults to false.
-  #
-  #enabled: true
-
-  # Uncomment the following to disable use of the OIDC discovery mechanism to
-  # discover endpoints. Defaults to true.
-  #
-  #discover: false
-
-  # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-  # discover the provider's endpoints.
-  #
-  # Required if 'enabled' is true.
-  #
-  #issuer: "https://accounts.example.com/"
-
-  # oauth2 client id to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_id: "provided-by-your-issuer"
-
-  # oauth2 client secret to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_secret: "provided-by-your-issuer"
-
-  # auth method to use when exchanging the token.
-  # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-  # 'none'.
-  #
-  #client_auth_method: client_secret_post
-
-  # list of scopes to request. This should normally include the "openid" scope.
-  # Defaults to ["openid"].
-  #
-  #scopes: ["openid", "profile"]
-
-  # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-  #
-  #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-  # the oauth2 token endpoint. Required if provider discovery is disabled.
-  #
-  #token_endpoint: "https://accounts.example.com/oauth2/token"
-
-  # the OIDC userinfo endpoint. Required if discovery is disabled and the
-  # "openid" scope is not requested.
-  #
-  #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
-  # URI where to fetch the JWKS. Required if discovery is disabled and the
-  # "openid" scope is used.
-  #
-  #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
-  # Uncomment to skip metadata verification. Defaults to false.
-  #
-  # Use this if you are connecting to a provider that is not OpenID Connect
-  # compliant.
-  # Avoid this in production.
-  #
-  #skip_verification: true
-
-  # Whether to fetch the user profile from the userinfo endpoint. Valid
-  # values are: "auto" or "userinfo_endpoint".
-  #
-  # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-  # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-  #
-  #user_profile_method: "userinfo_endpoint"
-
-  # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-  # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-  #
-  #allow_existing_users: true
-
-  # An external module can be provided here as a custom solution to mapping
-  # attributes returned from a OIDC provider onto a matrix user.
-  #
-  user_mapping_provider:
-    # The custom module's class. Uncomment to use a custom module.
-    # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
-    #
-    # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-    # for information on implementing a custom mapping provider.
-    #
-    #module: mapping_provider.OidcMappingProvider
-
-    # Custom configuration values for the module. This section will be passed as
-    # a Python dictionary to the user mapping provider module's `parse_config`
-    # method.
-    #
-    # The examples below are intended for the default provider: they should be
-    # changed if using a custom provider.
-    #
-    config:
-      # name of the claim containing a unique identifier for the user.
-      # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-      #
-      #subject_claim: "sub"
-
-      # Jinja2 template for the localpart of the MXID.
-      #
-      # When rendering, this template is given the following variables:
-      #  * user: The claims returned by the UserInfo Endpoint and/or in the ID
-      #    Token
-      #
-      # If this is not set, the user will be prompted to choose their
-      # own username.
-      #
-      #localpart_template: "{{ user.preferred_username }}"
-
-      # Jinja2 template for the display name to set on first login.
-      #
-      # If unset, no displayname will be set.
-      #
-      #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
-
-      # Jinja2 templates for extra attributes to send back to the client during
-      # login.
-      #
-      # Note that these are non-standard and clients will ignore them without modifications.
-      #
-      #extra_attributes:
-        #birthdate: "{{ user.birthdate }}"
+# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+# and login.
+#
+# Options for each entry include:
+#
+#   idp_id: a unique identifier for this identity provider. Used internally
+#       by Synapse; should be a single word such as 'github'.
+#
+#       Note that, if this is changed, users authenticating via that provider
+#       will no longer be recognised as the same user!
+#
+#   idp_name: A user-facing name for this identity provider, which is used to
+#       offer the user a choice of login mechanisms.
+#
+#   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+#       to discover endpoints. Defaults to true.
+#
+#   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+#       is enabled) to discover the provider's endpoints.
+#
+#   client_id: Required. oauth2 client id to use.
+#
+#   client_secret: Required. oauth2 client secret to use.
+#
+#   client_auth_method: auth method to use when exchanging the token. Valid
+#       values are 'client_secret_basic' (default), 'client_secret_post' and
+#       'none'.
+#
+#   scopes: list of scopes to request. This should normally include the "openid"
+#       scope. Defaults to ["openid"].
+#
+#   authorization_endpoint: the oauth2 authorization endpoint. Required if
+#       provider discovery is disabled.
+#
+#   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+#       disabled.
+#
+#   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+#       disabled and the 'openid' scope is not requested.
+#
+#   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+#       the 'openid' scope is used.
+#
+#   skip_verification: set to 'true' to skip metadata verification. Use this if
+#       you are connecting to a provider that is not OpenID Connect compliant.
+#       Defaults to false. Avoid this in production.
+#
+#   user_profile_method: Whether to fetch the user profile from the userinfo
+#       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+#       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+#       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+#       userinfo endpoint.
+#
+#   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+#       match a pre-existing account instead of failing. This could be used if
+#       switching from password logins to OIDC. Defaults to false.
+#
+#   user_mapping_provider: Configuration for how attributes returned from a OIDC
+#       provider are mapped onto a matrix user. This setting has the following
+#       sub-properties:
+#
+#       module: The class name of a custom mapping module. Default is
+#           'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+#           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+#           for information on implementing a custom mapping provider.
+#
+#       config: Configuration for the mapping provider module. This section will
+#           be passed as a Python dictionary to the user mapping provider
+#           module's `parse_config` method.
+#
+#           For the default provider, the following settings are available:
+#
+#             sub: name of the claim containing a unique identifier for the
+#                 user. Defaults to 'sub', which OpenID Connect compliant
+#                 providers should provide.
+#
+#             localpart_template: Jinja2 template for the localpart of the MXID.
+#                 If this is not set, the user will be prompted to choose their
+#                 own username.
+#
+#             display_name_template: Jinja2 template for the display name to set
+#                 on first login. If unset, no displayname will be set.
+#
+#             extra_attributes: a map of Jinja2 templates for extra attributes
+#                 to send back to the client during login.
+#                 Note that these are non-standard and clients will ignore them
+#                 without modifications.
+#
+#       When rendering, the Jinja2 templates are given a 'user' variable,
+#       which is set to the claims returned by the UserInfo Endpoint and/or
+#       in the ID Token.
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for information on how to configure these options.
+#
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format.
+#
+oidc_providers:
+  # Generic example
+  #
+  #- idp_id: my_idp
+  #  idp_name: "My OpenID provider"
+  #  discover: false
+  #  issuer: "https://accounts.example.com/"
+  #  client_id: "provided-by-your-issuer"
+  #  client_secret: "provided-by-your-issuer"
+  #  client_auth_method: client_secret_post
+  #  scopes: ["openid", "profile"]
+  #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+  #  token_endpoint: "https://accounts.example.com/oauth2/token"
+  #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+  #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+  #  skip_verification: true
+
+  # For use with Keycloak
+  #
+  #- idp_id: keycloak
+  #  idp_name: Keycloak
+  #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+  #  client_id: "synapse"
+  #  client_secret: "copy secret generated in Keycloak UI"
+  #  scopes: ["openid", "profile"]
+
+  # For use with Github
+  #
+  #- idp_id: github
+  #  idp_name: Github
+  #  discover: false
+  #  issuer: "https://github.com/"
+  #  client_id: "your-client-id" # TO BE FILLED
+  #  client_secret: "your-client-secret" # TO BE FILLED
+  #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+  #  token_endpoint: "https://github.com/login/oauth/access_token"
+  #  userinfo_endpoint: "https://api.github.com/user"
+  #  scopes: ["read:user"]
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{{ user.login }}"
+  #      display_name_template: "{{ user.name }}"
 
 # Enable Central Authentication Service (CAS) for registration and login.
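Because `oidc_providers` is now a list rather than a single object, code that consumes the parsed config selects an entry by its `idp_id`. A small illustrative sketch (not Synapse code; the dict stands in for a parsed homeserver.yaml and mirrors the commented Keycloak example above):

```python
config = {
    "oidc_providers": [
        {"idp_id": "keycloak", "idp_name": "Keycloak", "client_id": "synapse"},
    ]
}

def choose_provider(config: dict, idp_id: str) -> dict:
    """Return the `oidc_providers` entry with the given idp_id.

    Raises KeyError if the setting is absent or no entry matches."""
    for provider in config.get("oidc_providers") or []:
        if provider["idp_id"] == idp_id:
            return provider
    raise KeyError(f"no such identity provider: {idp_id}")
```

The `or []` keeps the lookup working when the setting is omitted or explicitly null, which the schema below permits.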

@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
 be used for demo purposes and any admin considering workers should already be
 running PostgreSQL.
 
+See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
+for a higher level overview.
+
 ## Main process/worker communication
 
 The processes communicate with each other via a Synapse-specific protocol called

@@ -40,7 +40,7 @@ class CasConfig(Config):
         self.cas_required_attributes = {}
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
 # Enable Central Authentication Service (CAS) for registration and login.
 #
 cas_config:
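The one-character change above matters: a bare `"""` keeps the newline that follows it, so the generated config section would start with a blank line, while the trailing backslash suppresses it. A minimal demonstration:

```python
# Without the backslash, the string begins with the newline after the
# opening triple-quote; with it, the first content line comes first.
unescaped = """
cas_config:
"""
escaped = """\
cas_config:
"""
```

`escaped` is exactly `unescaped` minus the leading newline.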

@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import string
-from typing import Optional, Type
+from typing import Iterable, Optional, Type
 
 import attr
@@ -33,16 +33,8 @@ class OIDCConfig(Config):
     section = "oidc"
 
     def read_config(self, config, **kwargs):
-        validate_config(MAIN_CONFIG_SCHEMA, config, ())
-
-        self.oidc_provider = None  # type: Optional[OidcProviderConfig]
-
-        oidc_config = config.get("oidc_config")
-        if oidc_config and oidc_config.get("enabled", False):
-            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
-            self.oidc_provider = _parse_oidc_config_dict(oidc_config)
-
-        if not self.oidc_provider:
+        self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+        if not self.oidc_providers:
             return
 
         try:
@@ -58,144 +50,153 @@ class OIDCConfig(Config):
     @property
     def oidc_enabled(self) -> bool:
         # OIDC is enabled if we have a provider
-        return bool(self.oidc_provider)
+        return bool(self.oidc_providers)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
-# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
-#
-# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-# for some example configurations.
-#
-oidc_config:
-  # Uncomment the following to enable authorization against an OpenID Connect
-  # server. Defaults to false.
-  #
-  #enabled: true
-
-  # Uncomment the following to disable use of the OIDC discovery mechanism to
-  # discover endpoints. Defaults to true.
-  #
-  #discover: false
-
-  # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-  # discover the provider's endpoints.
-  #
-  # Required if 'enabled' is true.
-  #
-  #issuer: "https://accounts.example.com/"
-
-  # oauth2 client id to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_id: "provided-by-your-issuer"
-
-  # oauth2 client secret to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_secret: "provided-by-your-issuer"
-
-  # auth method to use when exchanging the token.
-  # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-  # 'none'.
-  #
-  #client_auth_method: client_secret_post
-
-  # list of scopes to request. This should normally include the "openid" scope.
-  # Defaults to ["openid"].
-  #
-  #scopes: ["openid", "profile"]
-
-  # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-  #
-  #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-  # the oauth2 token endpoint. Required if provider discovery is disabled.
-  #
-  #token_endpoint: "https://accounts.example.com/oauth2/token"
-
-  # the OIDC userinfo endpoint. Required if discovery is disabled and the
-  # "openid" scope is not requested.
-  #
-  #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
-  # URI where to fetch the JWKS. Required if discovery is disabled and the
-  # "openid" scope is used.
-  #
-  #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
-  # Uncomment to skip metadata verification. Defaults to false.
-  #
-  # Use this if you are connecting to a provider that is not OpenID Connect
-  # compliant.
-  # Avoid this in production.
-  #
-  #skip_verification: true
-
-  # Whether to fetch the user profile from the userinfo endpoint. Valid
-  # values are: "auto" or "userinfo_endpoint".
-  #
-  # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-  # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-  #
-  #user_profile_method: "userinfo_endpoint"
-
-  # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-  # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-  #
-  #allow_existing_users: true
-
-  # An external module can be provided here as a custom solution to mapping
-  # attributes returned from a OIDC provider onto a matrix user.
-  #
-  user_mapping_provider:
-    # The custom module's class. Uncomment to use a custom module.
-    # Default is {mapping_provider!r}.
-    #
-    # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-    # for information on implementing a custom mapping provider.
-    #
-    #module: mapping_provider.OidcMappingProvider
-
-    # Custom configuration values for the module. This section will be passed as
-    # a Python dictionary to the user mapping provider module's `parse_config`
-    # method.
-    #
-    # The examples below are intended for the default provider: they should be
-    # changed if using a custom provider.
-    #
-    config:
-      # name of the claim containing a unique identifier for the user.
-      # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-      #
-      #subject_claim: "sub"
-
-      # Jinja2 template for the localpart of the MXID.
-      #
-      # When rendering, this template is given the following variables:
-      #  * user: The claims returned by the UserInfo Endpoint and/or in the ID
-      #    Token
-      #
-      # If this is not set, the user will be prompted to choose their
-      # own username.
-      #
-      #localpart_template: "{{{{ user.preferred_username }}}}"
-
-      # Jinja2 template for the display name to set on first login.
-      #
-      # If unset, no displayname will be set.
-      #
-      #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
-      # Jinja2 templates for extra attributes to send back to the client during
-      # login.
-      #
-      # Note that these are non-standard and clients will ignore them without modifications.
-      #
-      #extra_attributes:
-        #birthdate: "{{{{ user.birthdate }}}}"
+# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+# and login.
+#
+# Options for each entry include:
+#
+#   idp_id: a unique identifier for this identity provider. Used internally
+#       by Synapse; should be a single word such as 'github'.
+#
+#       Note that, if this is changed, users authenticating via that provider
+#       will no longer be recognised as the same user!
+#
+#   idp_name: A user-facing name for this identity provider, which is used to
+#       offer the user a choice of login mechanisms.
+#
+#   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+#       to discover endpoints. Defaults to true.
+#
+#   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+#       is enabled) to discover the provider's endpoints.
+#
+#   client_id: Required. oauth2 client id to use.
+#
+#   client_secret: Required. oauth2 client secret to use.
+#
+#   client_auth_method: auth method to use when exchanging the token. Valid
+#       values are 'client_secret_basic' (default), 'client_secret_post' and
+#       'none'.
+#
+#   scopes: list of scopes to request. This should normally include the "openid"
+#       scope. Defaults to ["openid"].
+#
+#   authorization_endpoint: the oauth2 authorization endpoint. Required if
+#       provider discovery is disabled.
+#
+#   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+#       disabled.
+#
+#   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+#       disabled and the 'openid' scope is not requested.
+#
+#   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+#       the 'openid' scope is used.
+#
+#   skip_verification: set to 'true' to skip metadata verification. Use this if
+#       you are connecting to a provider that is not OpenID Connect compliant.
+#       Defaults to false. Avoid this in production.
+#
+#   user_profile_method: Whether to fetch the user profile from the userinfo
+#       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+#       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+#       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+#       userinfo endpoint.
+#
+#   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+#       match a pre-existing account instead of failing. This could be used if
+#       switching from password logins to OIDC. Defaults to false.
+#
+#   user_mapping_provider: Configuration for how attributes returned from a OIDC
+#       provider are mapped onto a matrix user. This setting has the following
+#       sub-properties:
+#
+#       module: The class name of a custom mapping module. Default is
+#           {mapping_provider!r}.
+#           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+#           for information on implementing a custom mapping provider.
+#
+#       config: Configuration for the mapping provider module. This section will
+#           be passed as a Python dictionary to the user mapping provider
+#           module's `parse_config` method.
+#
+#           For the default provider, the following settings are available:
+#
+#             sub: name of the claim containing a unique identifier for the
+#                 user. Defaults to 'sub', which OpenID Connect compliant
+#                 providers should provide.
+#
+#             localpart_template: Jinja2 template for the localpart of the MXID.
+#                 If this is not set, the user will be prompted to choose their
+#                 own username.
+#
+#             display_name_template: Jinja2 template for the display name to set
+#                 on first login. If unset, no displayname will be set.
+#
+#             extra_attributes: a map of Jinja2 templates for extra attributes
+#                 to send back to the client during login.
+#                 Note that these are non-standard and clients will ignore them
+#                 without modifications.
+#
+#       When rendering, the Jinja2 templates are given a 'user' variable,
+#       which is set to the claims returned by the UserInfo Endpoint and/or
+#       in the ID Token.
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for information on how to configure these options.
+#
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format.
+#
+oidc_providers:
+  # Generic example
+  #
+  #- idp_id: my_idp
+  #  idp_name: "My OpenID provider"
+  #  discover: false
+  #  issuer: "https://accounts.example.com/"
+  #  client_id: "provided-by-your-issuer"
+  #  client_secret: "provided-by-your-issuer"
+  #  client_auth_method: client_secret_post
+  #  scopes: ["openid", "profile"]
+  #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+  #  token_endpoint: "https://accounts.example.com/oauth2/token"
+  #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+  #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+  #  skip_verification: true
+
+  # For use with Keycloak
+  #
+  #- idp_id: keycloak
+  #  idp_name: Keycloak
+  #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+  #  client_id: "synapse"
+  #  client_secret: "copy secret generated in Keycloak UI"
+  #  scopes: ["openid", "profile"]
+
+  # For use with Github
+  #
+  #- idp_id: github
+  #  idp_name: Github
+  #  discover: false
+  #  issuer: "https://github.com/"
+  #  client_id: "your-client-id" # TO BE FILLED
+  #  client_secret: "your-client-secret" # TO BE FILLED
+  #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+  #  token_endpoint: "https://github.com/login/oauth/access_token"
+  #  userinfo_endpoint: "https://api.github.com/user"
+  #  scopes: ["read:user"]
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{{{{ user.login }}}}"
+  #      display_name_template: "{{{{ user.name }}}}"
 """.format(
     mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
 )
@@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     },
 }
 
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
-# the `oidc_config` setting can either be None (as it is in the default
+# the `oidc_config` setting can either be None (which it used to be in the default
 # config), or an object. If an object, it is ignored unless it has an "enabled: True"
 # property.
 #
@@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
 # additional checks in the code.
 OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
 
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
 MAIN_CONFIG_SCHEMA = {
     "type": "object",
-    "properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
 }
def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
"""extract and parse the OIDC provider configs from the config dict
The configuration may contain either a single `oidc_config` object with an
`enabled: True` property, or a list of provider configurations under
`oidc_providers`, *or both*.
Returns a generator which yields the OidcProviderConfig objects
"""
validate_config(MAIN_CONFIG_SCHEMA, config, ())
for p in config.get("oidc_providers") or []:
yield _parse_oidc_config_dict(p)
# for backwards-compatibility, it is also possible to provide a single "oidc_config"
# object with an "enabled: True" property.
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
# MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
# it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
# above), so now we need to validate it.
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
yield _parse_oidc_config_dict(oidc_config)
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig """Take the configuration dict and parse it into an OidcProviderConfig
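The backwards-compatibility logic in `_parse_oidc_provider_configs` can be illustrated with a stripped-down sketch (schema validation and the real `OidcProviderConfig` type are omitted; the provider dicts below are placeholders, not config from this commit):

```python
def parse_provider_configs(config: dict):
    """Yield provider configs: the new-style `oidc_providers` list first,
    then the legacy single `oidc_config` block if it is enabled."""
    for provider in config.get("oidc_providers") or []:
        yield provider

    oidc_config = config.get("oidc_config")
    if oidc_config and oidc_config.get("enabled", False):
        yield oidc_config


config = {
    "oidc_providers": [{"idp_id": "github", "idp_name": "GitHub"}],
    "oidc_config": {"enabled": True, "issuer": "https://id.example.com/"},
}
# both the new-style provider and the enabled legacy block are yielded
providers = list(parse_provider_configs(config))
```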
View file
@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import os
-from distutils.util import strtobool
 
 import pkg_resources
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
 
 
 class AccountValidityConfig(Config):
@@ -86,12 +85,12 @@ class RegistrationConfig(Config):
     section = "registration"
 
     def read_config(self, config, **kwargs):
-        self.enable_registration = bool(
-            strtobool(str(config.get("enable_registration", False)))
+        self.enable_registration = strtobool(
+            str(config.get("enable_registration", False))
         )
         if "disable_registration" in config:
-            self.enable_registration = not bool(
-                strtobool(str(config["disable_registration"]))
+            self.enable_registration = not strtobool(
+                str(config["disable_registration"])
            )
 
         self.account_validity = AccountValidityConfig(
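With `distutils.util.strtobool` removed, Synapse carries its own copy in `synapse.util.stringutils`. A minimal sketch of such a helper, assuming it mirrors the distutils truth values (returning a real `bool` rather than distutils' 0/1):

```python
def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False.

    A local stand-in for the deprecated distutils.util.strtobool;
    raises ValueError for unrecognised values.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (val,))
```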
View file
@@ -17,7 +17,6 @@
 
 import abc
 import os
-from distutils.util import strtobool
 from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
 # for the sake of tests, it is set here while it cannot be configured on the
 # homeserver object itself.
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
View file
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(self._user_device_resync, sender_user_id)
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
View file
@@ -78,21 +78,28 @@ class OidcHandler:
 
     def __init__(self, hs: "HomeServer"):
         self._sso_handler = hs.get_sso_handler()
 
-        provider_conf = hs.config.oidc.oidc_provider
+        provider_confs = hs.config.oidc.oidc_providers
         # we should not have been instantiated if there is no configured provider.
-        assert provider_conf is not None
+        assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-
-        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
 
         Called at startup to ensure we have everything we need.
         """
-        await self._provider.load_metadata()
-        await self._provider.load_jwks()
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -184,6 +191,12 @@ class OidcHandler:
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
             self._sso_handler.render_error(
@@ -193,7 +206,7 @@ class OidcHandler:
 
         code = request.args[b"code"][0].decode()
 
-        await self._provider.handle_oidc_callback(request, session_data, code)
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
 
 
 class OidcError(Exception):
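The shape of the handler change above — a dict of providers keyed by `idp_id`, with callbacks for an unknown IdP rejected instead of assuming a single provider — can be sketched in plain Python (the dicts stand in for real provider config objects):

```python
class OidcHandlerSketch:
    """Toy version of the multi-provider lookup added in this commit."""

    def __init__(self, provider_confs):
        # we should not be instantiated without at least one provider
        assert provider_confs
        self._providers = {p["idp_id"]: p for p in provider_confs}

    def provider_for_session(self, idp_id):
        # the real handler renders an "unknown_idp" error page here
        return self._providers.get(idp_id)


handler = OidcHandlerSketch([{"idp_id": "github"}, {"idp_id": "google"}])
```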
View file
@@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
+            self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
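The essence of this fix — fail exactly once when the size limit is passed, then ignore any further chunks that arrive before the connection actually closes — can be shown without Twisted in a small sketch (class and method names here are illustrative, not Synapse's):

```python
class BodyExceededMaxSize(Exception):
    """Raised when a response body passes the configured size limit."""


class MaxSizeBuffer:
    """Accumulate body chunks, recording at most one size-limit error."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.length = 0
        self.error = None  # set once; later chunks are ignored

    def data_received(self, data: bytes) -> None:
        if self.error is not None:
            return  # already failed: bail early, like the dataReceived guard
        self.length += len(data)
        if self.max_size is not None and self.length >= self.max_size:
            self.error = BodyExceededMaxSize()
```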
View file
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
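The new protect-media endpoint is invoked like Synapse's other admin APIs. A small sketch of building the request (assuming the usual `/_synapse/admin/v1` prefix used by the existing admin routes; the base URL and media ID below are placeholders):

```python
def protect_media_request(base_url: str, media_id: str):
    """Return the (method, url) pair for the new protect-media admin API.

    The caller still needs to send the request with an admin access token,
    e.g. an `Authorization: Bearer <token>` header.
    """
    return "POST", "%s/_synapse/admin/v1/media/protect/%s" % (base_url, media_id)


method, url = protect_media_request("https://synapse.example.com", "abcdefg12345")
```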
View file
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 import logging
 import os
 import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
 ]
 
 
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
         )
 
 
-def respond_404(request):
+def respond_404(request: Request) -> None:
     respond_with_json(
         request,
         404,
@@ -79,8 +80,12 @@ def respond_404(request):
 
 
 async def respond_with_file(
-    request, media_type, file_path, file_size=None, upload_name=None
-):
+    request: Request,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
         respond_404(request)
 
 
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
     """Adds the correct response headers in preparation for responding with the
     media.
 
     Args:
-        request (twisted.web.http.Request)
-        media_type (str): The media/content type.
-        file_size (int): Size in bytes of the media, if known.
-        upload_name (str): The name of the requested file, if any.
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
     """
 
     def _quote(x):
@@ -153,6 +163,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-    request.setHeader(b"Content-Length", b"%d" % (file_size,))
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
     # Tell web crawlers to not index, archive, or follow links in media. This
@@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
 }
 
 
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
 
 
 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -308,22 +323,22 @@ class FileInfo:
         self.thumbnail_type = thumbnail_type
 
 
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.
 
     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.
 
     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])
 
     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None
 
     _, params = _parse_header(content_disposition[0])
 
@@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
     return upload_name
 
 
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed
 
     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -386,16 +400,16 @@ def _parse_header(line):
     return key, pdict
 
 
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed
 
     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
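For a feel of what `get_filename_from_headers` extracts, here is a toy counterpart built on the stdlib email parser (Synapse parses the raw bytes with its own cgi-style `_parse_header`/`_parseparam` instead; the function name below is illustrative):

```python
from email.message import Message


def filename_from_content_disposition(value: bytes):
    """Extract the `filename` parameter from a Content-Disposition value,
    returning None when the header carries no filename."""
    msg = Message()
    msg["Content-Disposition"] = value.decode("ascii", "replace")
    return msg.get_filename()
```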
View file
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #
 
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
View file
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
 
 from ._base import parse_media_id, respond_404
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
View file
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """
 
-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
 
     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
 
-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )
 
-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@ -126,7 +142,7 @@ class MediaFilePaths:
file_name, file_name,
) )
def remote_media_thumbnail_dir(self, server_name, file_id): def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join( return os.path.join(
self.base_path, self.base_path,
"remote_thumbnail", "remote_thumbnail",
@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:], file_id[4:],
) )
def url_cache_filepath_rel(self, media_id): def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id): def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file" "The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])] return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]), os.path.join(self.base_path, "url_cache", media_id[0:2]),
] ]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id): def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:], media_id[4:],
) )
def url_cache_thumbnail_dirs_to_delete(self, media_id): def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails" "The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
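The two-level sharding these path helpers implement can be seen in isolation. A minimal standalone sketch (outside Synapse; the media ID below is made up):

```python
import os


def local_media_filepath_rel(media_id: str) -> str:
    # Shard the ID into two two-character directory levels so that no
    # single directory ends up holding every uploaded file.
    return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])


print(local_media_filepath_rel("GerZNDnDZVjsOtar"))
```

On a POSIX filesystem this yields `local_content/Ge/rZ/NDnDZVjsOtar`.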

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels

-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths

         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements

         self.remote_media_linearizer = Linearizer(name="media_remote")

-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]

         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )

-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )

-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.

         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())

-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None

         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()

         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)

-        return t_byte_source
+        return None

     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:
         return {"width": m_width, "height": m_height}

-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)

         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
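The `# type:` comments added to `__init__` above are the comment-style annotations mypy accepts for attributes, useful where PEP 526 variable annotations are not wanted. A minimal sketch of the pattern (the `AccessTracker` class is illustrative, not part of Synapse):

```python
from typing import Optional, Set, Tuple


class AccessTracker:
    def __init__(self) -> None:
        # Comment-style annotations tell mypy the element types of these
        # otherwise-opaque empty sets.
        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
        self.recently_accessed_locals = set()  # type: Set[str]

    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
        # Remote media is keyed by (origin server, media ID); local media
        # needs only the ID.
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)
```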

@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence

+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender

 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)


-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.

     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """

-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file

-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
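The newly-typed `_write_file_synchronously(source: IO, dest: IO)` is a plain blocking copy. A rough sketch of what such a helper does (an approximation, not the exact Synapse body):

```python
import shutil
from io import BytesIO
from typing import IO


def _write_file_synchronously(source: IO, dest: IO) -> None:
    # Blocking copy. In Synapse this style of helper runs on a thread via
    # defer_to_thread, so the reactor is not stalled while large media is
    # written out.
    source.seek(0)
    shutil.copyfileobj(source, dest)


src, dst = BytesIO(b"media bytes"), BytesIO()
_write_file_synchronously(src, dst)
```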

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse

 import attr

 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request

 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string

 from ._base import FileInfo

+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)

 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()

         self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             self._start_expire_url_cache_data, 10 * 1000
         )

-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e

-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )

-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")


-def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
@@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
     return og


-def _calc_og(tree, media_uri):
+def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.

     # if we see any image URLs in the OG response, then spider them
@@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
             for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
         )
         og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])

     # TODO: delete the url downloads to stop diskfilling,
@@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
     return og


-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
     )


-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-    if not url[2].startswith("/"):
-        url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+    if not url_parts[2].startswith("/"):
+        url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)


-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")


-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )


-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?

@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer

-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """

-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """

+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote

-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)

-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None
@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            return await maybe_awaitable(self.backend.store_file(path, file_info))
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
                     logger.exception("Error storing file")

             run_in_background(store)
-            return None

-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
         return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.

     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """

-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)

-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""

         primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)

-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )

-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""

         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))

+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
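Making `StorageProvider` an `abc.ABCMeta` class turns a missing `store_file`/`fetch` implementation from a silent no-op into a construction-time error. A reduced sketch of the effect (simplified signature, without the `FileInfo` parameter):

```python
import abc


class StorageProvider(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def store_file(self, path: str) -> None:
        """Store the file found at `path`."""


# A subclass that forgets to implement the abstract method now fails as
# soon as it is instantiated, rather than when the method is first called:
class IncompleteProvider(StorageProvider):
    pass


try:
    IncompleteProvider()
except TypeError as e:
    print("refused:", e)
```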

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
 import logging

+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage

 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)


 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True

-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()

         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname

-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
         self.media_repo.mark_recently_accessed(server_name, media_id)

     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)

         if not media_info:
@@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)

         if not media_info:
@@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)

         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
             raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail( async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type self,
): request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of # We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails. # downloading the remote file and generating our own thumbnails.
@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
def _select_thumbnail( def _select_thumbnail(
self, self,
desired_width, desired_width: int,
desired_height, desired_height: int,
desired_method, desired_method: str,
desired_type, desired_type: str,
thumbnail_infos, thumbnail_infos,
): ) -> dict:
d_w = desired_width d_w = desired_width
d_h = desired_height d_h = desired_height


@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Tuple
from PIL import Image from PIL import Image
@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path): def __init__(self, input_path: str):
try: try:
self.image = Image.open(input_path) self.image = Image.open(input_path)
except OSError as e: except OSError as e:
@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF # A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e) logger.info("Error parsing image EXIF information: %s", e)
def transpose(self): def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag """Transpose the image using its EXIF Orientation tag
Returns: Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels. A tuple containing the new image size in pixels as (width, height).
""" """
if self.transpose_method is not None: if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method) self.image = self.image.transpose(self.transpose_method)
@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None self.image.info["exif"] = None
return self.image.size return self.image.size
def aspect(self, max_width, max_height): def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which """Calculate the largest size that preserves aspect ratio which
fits within the given rectangle:: fits within the given rectangle::
@ -91,7 +93,7 @@ class Thumbnailer:
else: else:
return (max_height * self.width) // self.height, max_height return (max_height * self.width) // self.height, max_height
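The aspect-fit arithmetic above can be sketched standalone (an illustration of the same integer-division convention, independent of the `Thumbnailer` class):

```python
def aspect_fit(width: int, height: int, max_width: int, max_height: int):
    """Largest (w, h) with the same aspect ratio that fits inside the box."""
    if max_width * height < max_height * width:
        # Constrained by width: scale the height down to match.
        return max_width, (max_width * height) // width
    else:
        # Constrained by height: scale the width down to match.
        return (max_height * width) // height, max_height

print(aspect_fit(1000, 500, 100, 100))  # wide image -> (100, 50)
print(aspect_fit(500, 1000, 100, 100))  # tall image -> (50, 100)
```

Cross-multiplying (`max_width * height < max_height * width`) avoids floating-point division while comparing the two aspect ratios.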
def _resize(self, width, height): def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB # 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which # otherwise they will be scaled using nearest neighbour which
# looks awful # looks awful
@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB") self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS) return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type): def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions. """Rescales the image to the given dimensions.
Returns: Returns:
@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height) scaled = self._resize(width, height)
return self._encode_image(scaled, output_type) return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type): def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving """Rescales and crops the image to the given dimensions preserving
aspect:: aspect::
(w_in / h_in) = (w_scaled / h_scaled) (w_in / h_in) = (w_scaled / h_scaled)
@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height)) cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type) return self._encode_image(cropped, output_type)
def _encode_image(self, output_image, output_type): def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO() output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type] fmt = self.FORMATS[output_type]
if fmt == "JPEG": if fmt == "JPEG":

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,18 +15,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource): class UploadResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point

View file

@ -16,6 +16,8 @@
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import attr
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn.
"""
# The last room_id/depth/stream processed.
room_id = attr.ib(type=str)
depth = attr.ib(type=int)
stream = attr.ib(type=int)
# Number of rows processed
processed_count = attr.ib(type=int)
# Map from room_id to last depth/stream processed for each room that we have
# processed all events for (i.e. the rooms we can flip the
# `has_auth_chain_index` for)
finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
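The same frozen-record pattern can be shown with the standard library's `dataclasses` as a stand-in for the `@attr.s(slots=True, frozen=True)` declaration above (field names mirror the attrs class; this is an illustration, not Synapse's code):

```python
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(frozen=True)
class CalculateChainCover:
    """Illustrative stand-in for the attrs-based return value above."""
    room_id: str
    depth: int
    stream: int
    processed_count: int
    finished_room_map: Dict[str, Tuple[int, int]]

result = CalculateChainCover("!room:example.org", 10, 42, 100, {})
# Frozen instances reject mutation, so a transaction's result can be
# passed around without risk of being modified in place.
try:
    result.depth = 11
except Exception:
    print("immutable")
```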
class EventsBackgroundUpdatesStore(SQLBaseStore): class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@ -719,53 +740,94 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "") current_room_id = progress.get("current_room_id", "")
# Have we finished processing the current room.
finished = progress.get("finished", True)
# Where we've processed up to in the room, defaults to the start of the # Where we've processed up to in the room, defaults to the start of the
# room. # room.
last_depth = progress.get("last_depth", -1) last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1) last_stream = progress.get("last_stream", -1)
# Have we set the `has_auth_chain_index` for the room yet. result = await self.db_pool.runInteraction(
has_set_room_has_chain_index = progress.get( "_chain_cover_index",
"has_set_room_has_chain_index", False self._calculate_chain_cover_txn,
current_room_id,
last_depth,
last_stream,
batch_size,
single_room=False,
) )
finished = result.processed_count == 0
total_rows_processed = result.processed_count
current_room_id = result.room_id
last_depth = result.depth
last_stream = result.stream
for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that it's possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
# Handle any events that might have raced with us flipping the
# bit above.
result = await self.db_pool.runInteraction(
"_chain_cover_index",
self._calculate_chain_cover_txn,
room_id,
depth,
stream,
batch_size=None,
single_room=True,
)
total_rows_processed += result.processed_count
if finished: if finished:
# If we've finished with the previous room (or it's our first
# iteration) we move on to the next room.
def _get_next_room(txn: Cursor) -> Optional[str]:
sql = """
SELECT room_id FROM rooms
WHERE room_id > ?
AND (
NOT has_auth_chain_index
OR has_auth_chain_index IS NULL
)
ORDER BY room_id
LIMIT 1
"""
txn.execute(sql, (current_room_id,))
row = txn.fetchone()
if row:
return row[0]
return None
current_room_id = await self.db_pool.runInteraction(
"_chain_cover_index", _get_next_room
)
if not current_room_id:
await self.db_pool.updates._end_background_update("chain_cover") await self.db_pool.updates._end_background_update("chain_cover")
return 0 return total_rows_processed
logger.debug("Adding chain cover to %s", current_room_id) await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
},
)
return total_rows_processed
def _calculate_chain_cover_txn(
self,
txn: Cursor,
last_room_id: str,
last_depth: int,
last_stream: int,
batch_size: Optional[int],
single_room: bool,
) -> _CalculateChainCover:
"""Calculate the chain cover for `batch_size` events, ordered by
`(room_id, depth, stream)`.
Args:
txn,
last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
tuple to fetch results after.
batch_size: The maximum number of events to process. If None then
no limit.
single_room: Whether to calculate the index for just the given
room.
"""
def _calculate_auth_chain(
txn: Cursor, last_depth: int, last_stream: int
) -> Tuple[int, int, int]:
# Get the next set of events in the room (that we haven't already # Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order. # computed chain cover for). We do this in topological order.
@ -774,43 +836,66 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
tuple_clause, tuple_args = make_tuple_comparison_clause( tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine, self.database_engine,
[ [
("events.room_id", last_room_id),
("topological_ordering", last_depth), ("topological_ordering", last_depth),
("stream_ordering", last_stream), ("stream_ordering", last_stream),
], ],
) )
extra_clause = ""
if single_room:
extra_clause = "AND events.room_id = ?"
tuple_args.append(last_room_id)
sql = """ sql = """
SELECT SELECT
event_id, state_events.type, state_events.state_key, event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering topological_ordering, stream_ordering,
events.room_id
FROM events FROM events
INNER JOIN state_events USING (event_id) INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id) LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id) LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE events.room_id = ? WHERE event_auth_chains.event_id IS NULL
AND event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s AND %(tuple_cmp)s
ORDER BY topological_ordering, stream_ordering %(extra)s
LIMIT ? ORDER BY events.room_id, topological_ordering, stream_ordering
%(limit)s
""" % { """ % {
"tuple_cmp": tuple_clause, "tuple_cmp": tuple_clause,
"limit": "LIMIT ?" if batch_size is not None else "",
"extra": extra_clause,
} }
args = [current_room_id] if batch_size is not None:
args.extend(tuple_args) tuple_args.append(batch_size)
args.append(batch_size)
txn.execute(sql, args) txn.execute(sql, tuple_args)
rows = txn.fetchall() rows = txn.fetchall()
# Put the results in the necessary format for # Put the results in the necessary format for
# `_add_chain_cover_index` # `_add_chain_cover_index`
event_to_room_id = {row[0]: current_room_id for row in rows} event_to_room_id = {row[0]: row[5] for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows} event_to_types = {row[0]: (row[1], row[2]) for row in rows}
# Calculate the new last position we've processed up to.
new_last_depth = rows[-1][3] if rows else last_depth # type: int new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int new_last_stream = rows[-1][4] if rows else last_stream # type: int
new_last_room_id = rows[-1][5] if rows else "" # type: str
# Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also
# need to include the room passed in if it's not included in the result
# set (as we then know we've processed all events in said room).
#
# This is the set of rooms that we can now safely flip the
# `has_auth_chain_index` bit for.
finished_rooms = {
row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
}
if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
finished_rooms[last_room_id] = (last_depth, last_stream)
count = len(rows) count = len(rows)
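The `(room_id, depth, stream)` resume point above is classic keyset pagination; a minimal plain-Python sketch of the comparison it relies on (standing in for the SQL generated by `make_tuple_comparison_clause`):

```python
# Toy rows ordered the same way the query orders them:
# (room_id, topological_ordering, stream_ordering).
rows = [
    ("!a", 1, 1), ("!a", 1, 2), ("!a", 2, 1),
    ("!b", 1, 1), ("!b", 1, 2),
]

def next_batch(last, batch_size):
    # Lexicographic tuple comparison gives a stable resume point:
    # everything strictly after `last`, up to `batch_size` rows.
    out = [r for r in sorted(rows) if r > last]
    return out[:batch_size]

print(next_batch(("!a", 1, 2), 2))  # [('!a', 2, 1), ('!b', 1, 1)]
```

Because the ordering and the comparison use the same tuple, the update can stop after any batch and resume exactly where it left off.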
@ -826,76 +911,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
event_to_auth_chain = {} # type: Dict[str, List[str]] event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events: for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append( event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
row["auth_id"]
)
# Calculate and persist the chain cover index for this set of events. # Calculate and persist the chain cover index for this set of events.
# #
# Annoyingly we need to gut wrench into the persist event store so that # Annoyingly we need to gut wrench into the persist event store so that
# we can reuse the function to calculate the chain cover for rooms. # we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index( PersistEventsStore._add_chain_cover_index(
txn, txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
self.db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
) )
return new_last_depth, new_last_stream, count return _CalculateChainCover(
room_id=new_last_room_id,
last_depth, last_stream, count = await self.db_pool.runInteraction( depth=new_last_depth,
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream stream=new_last_stream,
processed_count=count,
finished_room_map=finished_rooms,
) )
total_rows_processed = count
if count < batch_size and not has_set_room_has_chain_index:
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": current_room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed += count
# Note that at this point its technically possible that more events
# than our `batch_size` have been persisted without their chain
# cover, so we need to continue processing this room if the last
# count returned was equal to the `batch_size`.
if count < batch_size:
# We've finished calculating the index for this room, move on to the
# next room.
await self.db_pool.updates._background_update_progress(
"chain_cover", {"current_room_id": current_room_id, "finished": True},
)
else:
# We still have outstanding events to calculate the index for.
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
"has_auth_chain_index": has_set_room_has_chain_index,
"finished": False,
},
)
return total_rows_processed


@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_local_media_before( async def get_local_media_before(
self, before_ts: int, size_gt: int, keep_profiles: bool, self, before_ts: int, size_gt: int, keep_profiles: bool,
) -> Optional[List[str]]: ) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL) # to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts` # compare with `created_ts`


@ -17,14 +17,13 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.push import PusherConfig, ThrottleParams from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection from synapse.storage.types import Connection
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING: if TYPE_CHECKING:
@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name, "device_display_name": device_display_name,
"ts": pushkey_ts, "ts": pushkey_ts,
"lang": lang, "lang": lang,
"data": bytearray(encode_canonical_json(data)), "data": json_encoder.encode(data),
"last_stream_ordering": last_stream_ordering, "last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,


@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems: if len(items) <= maxitems:
return str(items) return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
def strtobool(val: str) -> bool:
"""Convert a string representation of truth to True or False
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
This is lifted from distutils.util.strtobool, with the exception that it actually
returns a bool, rather than an int.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError("invalid truth value %r" % (val,))
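Quick usage of the replacement helper (a self-contained copy of the function above; note it returns a real `bool`, unlike `distutils.util.strtobool`, which returned `0`/`1`):

```python
def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False."""
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    elif val in ("n", "no", "f", "false", "off", "0"):
        return False
    else:
        raise ValueError("invalid truth value %r" % (val,))

assert strtobool("Yes") is True          # case-insensitive
assert strtobool("off") is False
assert isinstance(strtobool("1"), bool)  # a real bool, not an int
```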


@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client) hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler() self.handler = hs.get_oidc_handler()
self.provider = self.handler._provider self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler() sso_handler = hs.get_sso_handler()
# Mock the render error method. # Mock the render error method.
self.render_error = Mock(return_value=None) self.render_error = Mock(return_value=None)
@ -866,7 +866,7 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
provider = handler._provider provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={}) provider._exchange_code = simple_async_mock(return_value={})
provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo)


@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# Expire both caches and repeat the request # Expire both caches and repeat the request
self.reactor.pump((10000.0,)) self.reactor.pump((10000.0,))
# Repated the request, this time it should fail if the lookup fails. # Repeat the request, this time it should fail if the lookup fails.
fetch_d = defer.ensureDeferred( fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv") self.well_known_resolver.get_well_known(b"testserv")
) )
@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
) )
# The result is sucessful, but disabled delegation. # The result is successful, but disabled delegation.
r = self.successResultOf(fetch_d) r = self.successResultOf(fetch_d)
self.assertIsNone(r.delegated_server) self.assertIsNone(r.delegated_server)

tests/http/test_client.py Normal file

@ -0,0 +1,101 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from mock import Mock
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
def setUp(self):
"""Start reading the body, returns the response, result and proto"""
self.response = Mock()
self.result = BytesIO()
self.deferred = read_body_with_max_size(self.response, self.result, 6)
# Fish the protocol out of the response.
self.protocol = self.response.deliverBody.call_args[0][0]
self.protocol.transport = Mock()
def _cleanup_error(self):
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f):
called[0] = True
self.deferred.addErrback(errback)
self.assertTrue(called[0])
def test_no_error(self):
"""A response that is NOT too large."""
# Start sending data.
self.protocol.dataReceived(b"12345")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"12345")
self.assertEqual(self.deferred.result, 5)
def test_too_large(self):
"""A response which is too large raises an exception."""
# Start sending data.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()
def test_multiple_packets(self):
"""Data should be accummulated through mutliple packets."""
# Start sending data.
self.protocol.dataReceived(b"12")
self.protocol.dataReceived(b"34")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234")
self.assertEqual(self.deferred.result, 4)
def test_additional_data(self):
"""A connection can receive data after being closed."""
# Start sending data.
self.protocol.dataReceived(b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self.protocol.transport.loseConnection.assert_called_once()
# More data might have come in.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()
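The behaviour these tests exercise — accumulate bytes, abort once a limit is crossed — can be sketched standalone. This is an illustration of the pattern only, not Synapse's actual `read_body_with_max_size` implementation (which is built on a Twisted protocol and a Deferred):

```python
class MaxSizeError(Exception):
    """Stand-in for BodyExceededMaxSize."""

class BodyReader:
    def __init__(self, max_size: int):
        self.max_size = max_size
        self.buf = bytearray()
        self.failed = False

    def data_received(self, chunk: bytes) -> None:
        self.buf.extend(chunk)
        if len(self.buf) > self.max_size:
            # Real code would also drop the connection here, so more data
            # arriving afterwards (as in test_additional_data) is tolerated.
            self.failed = True

    def finish(self) -> int:
        if self.failed:
            raise MaxSizeError(len(self.buf))
        return len(self.buf)

r = BodyReader(6)
r.data_received(b"12345")
print(r.finish())  # 5
```

The key detail, mirrored in `test_additional_data`, is that failure is recorded as state rather than raised from `data_received`, since the transport may keep delivering packets after the limit is hit.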


@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
# Allow for uploading and downloading to/from the media repo # Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"] self.download_resource = self.media_repo.children[b"download"]
@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Mark the second item as safe from quarantine. # Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/") _, media_id_2 = server_and_media_id_2.split("/")
self.get_success(self.store.mark_local_media_as_safe(media_id_2)) # Quarantine the media
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine all media by this user # Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(


@@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session_id = channel.json_body["session"]
 
         # do the OIDC auth, but auth as the wrong user
-        channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id)
+        channel = self.helper.auth_via_oidc(
+            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        )
 
         # that should return a failure message
         self.assertSubstring("We were unable to validate", channel.text_body)
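The change above has `auth_via_oidc` take a userinfo dict rather than a bare localpart, so tests can control arbitrary OIDC claims. A hypothetical stand-in (not the real test helper) showing the shape of the new call:

```python
from typing import Optional

def auth_via_oidc(user_info_dict: dict, ui_auth_session_id: Optional[str] = None) -> str:
    # Hypothetical stand-in: the helper now receives a full userinfo payload,
    # of which the "sub" claim identifies the user being authenticated.
    sub = user_info_dict["sub"]
    return "oidc-auth:%s:%s" % (sub, ui_auth_session_id or "-")
```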


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 from twisted.trial import unittest
@@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def test_background_update(self):
-        """Test that the background update to calculate auth chains for historic
-        rooms works correctly.
-        """
-
-        # Create a room
-        user_id = self.register_user("foo", "pass")
-        token = self.login("foo", "pass")
-        room_id = self.helper.create_room_as(user_id, tok=token)
-        requester = create_requester(user_id)
-
-        store = self.hs.get_datastore()
-
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+        self.token = self.login("foo", "pass")
+        self.requester = create_requester(self.user_id)
+
+    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+        """Insert a room without a chain cover index.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
         # Mark the room as not having a chain cover index
         self.get_success(
-            store.db_pool.simple_update(
+            self.store.db_pool.simple_update(
                 table="rooms",
                 keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": False},
@@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
 
         # Create a fork in the DAG with different events.
         event_handler = self.hs.get_event_creation_handler()
-        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        latest_event_ids = self.get_success(
+            self.store.get_prev_events_for_room(room_id)
+        )
 
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
             )
         )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state1 = list(self.get_success(context.get_current_state_ids()).values())
+        state1 = set(self.get_success(context.get_current_state_ids()).values())
 
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
            )
        )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state2 = list(self.get_success(context.get_current_state_ids()).values())
+        state2 = set(self.get_success(context.get_current_state_ids()).values())
 
         # Delete the chain cover info.
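The hunk above also switches `state1`/`state2` from lists to sets of event IDs, matching the shape the auth-chain-difference code consumes. As a loose illustration of why sets are the right container (the real algorithm walks full auth chains; this shows only the set arithmetic, with made-up event IDs):

```python
# Two forked state views, as sets of event IDs (illustrative IDs only).
state1 = {"$create", "$power_levels", "$fork_a"}
state2 = {"$create", "$power_levels", "$fork_b"}

# The "difference" is everything not shared by every state set; with sets,
# duplicates and ordering are irrelevant and this is plain set arithmetic.
difference = (state1 | state2) - (state1 & state2)
```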
@@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
             txn.execute("DELETE FROM event_auth_chains")
             txn.execute("DELETE FROM event_auth_chain_links")
 
-        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
 
         # Insert and run the background update.
         self.get_success(
-            store.db_pool.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "chain_cover", "progress_json": "{}"},
             )
         )
 
         # Ugh, have to reset this flag
-        store.db_pool.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            store.db_pool.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Test that the `has_auth_chain_index` has been set
-        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
 
         # Test that calculating the auth chain difference using the newly
         # calculated chain cover works.
         self.get_success(
-            store.db_pool.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test",
-                store._get_auth_chain_difference_using_cover_index_txn,
+                self.store._get_auth_chain_difference_using_cover_index_txn,
                 room_id,
-                [state1, state2],
+                states,
             )
         )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states2 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
+            )
+        )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create the rooms
+        room_id1, _ = self._generate_room()
+        room_id2, _ = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id1, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id2, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
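All of the new tests drive the updater with the same reset-and-loop pattern: reset the `_all_done` flag, then run update batches until completion, counting iterations. Stripped of the Synapse machinery, the pattern is (names here are illustrative, not Synapse APIs):

```python
def drive_background_updates(do_next_batch, has_completed) -> int:
    # Run update batches until the updater reports completion, counting
    # iterations so callers can assert multi-batch behaviour.
    iterations = 0
    while not has_completed():
        iterations += 1
        do_next_batch()
    return iterations

# Fake updater that needs three batches to finish.
remaining = [3]

def do_next_batch():
    remaining[0] -= 1

def has_completed():
    return remaining[0] == 0
```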