diff --git a/changelog.d/9086.feature b/changelog.d/9086.feature new file mode 100644 index 0000000000..3e678e24d5 --- /dev/null +++ b/changelog.d/9086.feature @@ -0,0 +1 @@ +Add an admin API for protecting local media from quarantine. diff --git a/changelog.d/9093.misc b/changelog.d/9093.misc new file mode 100644 index 0000000000..53eb8f72a8 --- /dev/null +++ b/changelog.d/9093.misc @@ -0,0 +1 @@ +Add type hints to media repository. diff --git a/changelog.d/9108.bugfix b/changelog.d/9108.bugfix new file mode 100644 index 0000000000..465ef63508 --- /dev/null +++ b/changelog.d/9108.bugfix @@ -0,0 +1 @@ +Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. diff --git a/changelog.d/9110.feature b/changelog.d/9110.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9110.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9117.bugfix b/changelog.d/9117.bugfix new file mode 100644 index 0000000000..233a76d18b --- /dev/null +++ b/changelog.d/9117.bugfix @@ -0,0 +1 @@ +Fix corruption of `pushers` data when a postgres bouncer is used. diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9124.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/changelog.d/9125.misc b/changelog.d/9125.misc new file mode 100644 index 0000000000..08459caf5a --- /dev/null +++ b/changelog.d/9125.misc @@ -0,0 +1 @@ +Remove dependency on `distutils`. diff --git a/changelog.d/9130.feature b/changelog.d/9130.feature new file mode 100644 index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9130.feature @@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/debian/changelog b/debian/changelog index 609436bf75..1c6308e3a2 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium + + * Remove dependency on `python3-distutils`. + + -- Richard van der Hoff Fri, 15 Jan 2021 12:44:19 +0000 + matrix-synapse-py3 (1.25.0) stable; urgency=medium [ Dan Callahan ] diff --git a/debian/control b/debian/control index b10401be43..8167a901a4 100644 --- a/debian/control +++ b/debian/control @@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1) Depends: adduser, debconf, - python3-distutils|libpython3-stdlib (<< 3.6), ${misc:Depends}, ${shlibs:Depends}, ${synapse:pydepends}, diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index dfb8c5d751..90faeaaef0 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [Quarantining media by ID](#quarantining-media-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) + * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) @@ -123,6 +124,29 @@ The following fields are returned in the JSON response body: * `num_quarantined`: integer - The number of media items successfully quarantined +## Protecting media from being quarantined + +This API protects a single piece of local media from being quarantined using the +above APIs. This is useful for sticker packs and other shared media which you do +not want to get quarantined, especially when +[quarantining media in a room](#quarantining-media-in-a-room). + +Request: + +``` +POST /_synapse/admin/v1/media/protect/ + +{} +``` + +Where `media_id` is in the form of `abcdefg12345...`. + +Response: + +```json +{} +``` + # Delete local media This API deletes the *local* media from the disk of your own server. This includes any local thumbnails and copies of media downloaded from diff --git a/docs/openid.md b/docs/openid.md index ffa4238fff..b86ae89768 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -42,11 +42,10 @@ as follows: * For other installation mechanisms, see the documentation provided by the maintainer. -To enable the OpenID integration, you should then add an `oidc_config` section -to your configuration file (or uncomment the `enabled: true` line in the -existing section). See [sample_config.yaml](./sample_config.yaml) for some -sample settings, as well as the text below for example configurations for -specific providers. +To enable the OpenID integration, you should then add a section to the `oidc_providers` +setting in your configuration file (or uncomment one of the existing examples). +See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as +the text below for example configurations for specific providers. ## Sample configs @@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links. Edit your Synapse config file and change the `oidc_config` section: ```yaml -oidc_config: - enabled: true - issuer: "https://login.microsoftonline.com//v2.0" - client_id: "" - client_secret: "" - scopes: ["openid", "profile"] - authorization_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/authorize" - token_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/token" - userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" +oidc_providers: + - idp_id: microsoft + idp_name: Microsoft + issuer: "https://login.microsoftonline.com//v2.0" + client_id: "" + client_secret: "" + scopes: ["openid", "profile"] + authorization_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/authorize" + token_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/token" + userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username.split('@')[0] }}" - display_name_template: "{{ user.name }}" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username.split('@')[0] }}" + display_name_template: "{{ user.name }}" ``` ### [Dex][dex-idp] @@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`. Synapse config: ```yaml -oidc_config: - enabled: true - skip_verification: true # This is needed as Dex is served on an insecure endpoint - issuer: "http://127.0.0.1:5556/dex" - client_id: "synapse" - client_secret: "secret" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.name }}" - display_name_template: "{{ user.name|capitalize }}" +oidc_providers: + - idp_id: dex + idp_name: "My Dex server" + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + client_id: "synapse" + client_secret: "secret" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.name }}" + display_name_template: "{{ user.name|capitalize }}" ``` ### [Keycloak][keycloak-idp] @@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to 8. Copy Secret ```yaml -oidc_config: - enabled: true - issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" - client_id: "synapse" - client_secret: "copy secret generated from above" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: keycloak + idp_name: "My KeyCloak server" + issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + client_id: "synapse" + client_secret: "copy secret generated from above" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### [Auth0][auth0] @@ -191,16 +193,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: auth0 + idp_name: Auth0 + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitHub @@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set. Synapse config: ```yaml -oidc_config: - enabled: true - discover: false - issuer: "https://github.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - authorization_endpoint: "https://github.com/login/oauth/authorize" - token_endpoint: "https://github.com/login/oauth/access_token" - userinfo_endpoint: "https://api.github.com/user" - scopes: ["read:user"] - user_mapping_provider: - config: - subject_claim: "id" - localpart_template: "{{ user.login }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: github + idp_name: Github + discover: false + issuer: "https://github.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: ["read:user"] + user_mapping_provider: + config: + subject_claim: "id" + localpart_template: "{{ user.login }}" + display_name_template: "{{ user.name }}" ``` ### [Google][google-idp] @@ -243,16 +247,17 @@ oidc_config: 2. add an "OAuth Client ID" for a Web Application under "Credentials". 3. Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml - oidc_config: - enabled: true - issuer: "https://accounts.google.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.given_name|lower }}" - display_name_template: "{{ user.name }}" + oidc_providers: + - idp_id: google + idp_name: Google + issuer: "https://accounts.google.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.given_name|lower }}" + display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/oidc/callback`. @@ -266,16 +271,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: twitch + idp_name: Twitch + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitLab @@ -287,16 +293,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://gitlab.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - scopes: ["openid", "read_user"] - user_profile_method: "userinfo_endpoint" - user_mapping_provider: - config: - localpart_template: '{{ user.nickname }}' - display_name_template: '{{ user.name }}' +oidc_providers: + - idp_id: gitlab + idp_name: Gitlab + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9da351f9f3..ae995efe9b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1709,141 +1709,149 @@ saml2_config: #idp_entityid: 'https://our_idp/entityid' -# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. +# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration +# and login. +# +# Options for each entry include: +# +# idp_id: a unique identifier for this identity provider. Used internally +# by Synapse; should be a single word such as 'github'. +# +# Note that, if this is changed, users authenticating via that provider +# will no longer be recognised as the same user! +# +# idp_name: A user-facing name for this identity provider, which is used to +# offer the user a choice of login mechanisms. +# +# discover: set to 'false' to disable the use of the OIDC discovery mechanism +# to discover endpoints. Defaults to true. +# +# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery +# is enabled) to discover the provider's endpoints. +# +# client_id: Required. oauth2 client id to use. +# +# client_secret: Required. oauth2 client secret to use. +# +# client_auth_method: auth method to use when exchanging the token. Valid +# values are 'client_secret_basic' (default), 'client_secret_post' and +# 'none'. +# +# scopes: list of scopes to request. This should normally include the "openid" +# scope. Defaults to ["openid"]. +# +# authorization_endpoint: the oauth2 authorization endpoint. Required if +# provider discovery is disabled. +# +# token_endpoint: the oauth2 token endpoint. Required if provider discovery is +# disabled. +# +# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is +# disabled and the 'openid' scope is not requested. +# +# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and +# the 'openid' scope is used. +# +# skip_verification: set to 'true' to skip metadata verification. Use this if +# you are connecting to a provider that is not OpenID Connect compliant. +# Defaults to false. Avoid this in production. +# +# user_profile_method: Whether to fetch the user profile from the userinfo +# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. +# +# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is +# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the +# userinfo endpoint. +# +# allow_existing_users: set to 'true' to allow a user logging in via OIDC to +# match a pre-existing account instead of failing. This could be used if +# switching from password logins to OIDC. Defaults to false. +# +# user_mapping_provider: Configuration for how attributes returned from a OIDC +# provider are mapped onto a matrix user. This setting has the following +# sub-properties: +# +# module: The class name of a custom mapping module. Default is +# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. +# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers +# for information on implementing a custom mapping provider. +# +# config: Configuration for the mapping provider module. This section will +# be passed as a Python dictionary to the user mapping provider +# module's `parse_config` method. +# +# For the default provider, the following settings are available: +# +# sub: name of the claim containing a unique identifier for the +# user. Defaults to 'sub', which OpenID Connect compliant +# providers should provide. +# +# localpart_template: Jinja2 template for the localpart of the MXID. +# If this is not set, the user will be prompted to choose their +# own username. +# +# display_name_template: Jinja2 template for the display name to set +# on first login. If unset, no displayname will be set. +# +# extra_attributes: a map of Jinja2 templates for extra attributes +# to send back to the client during login. +# Note that these are non-standard and clients will ignore them +# without modifications. +# +# When rendering, the Jinja2 templates are given a 'user' variable, +# which is set to the claims returned by the UserInfo Endpoint and/or +# in the ID Token. # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md -# for some example configurations. +# for information on how to configure these options. # -oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. +# For backwards compatibility, it is also possible to configure a single OIDC +# provider via an 'oidc_config' setting. This is now deprecated and admins are +# advised to migrate to the 'oidc_providers' format. +# +oidc_providers: + # Generic example # - #enabled: true + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. + # For use with Keycloak # - #discover: false + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. + # For use with Github # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{ user.preferred_username }}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{ user.given_name }} {{ user.last_name }}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{ user.birthdate }}" - + #- idp_id: google + # idp_name: Google + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{ user.login }" + # display_name_template: "{ user.name }" # Enable Central Authentication Service (CAS) for registration and login. diff --git a/docs/workers.md b/docs/workers.md index 7fb651bba4..cc5090f224 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only be used for demo purposes and any admin considering workers should already be running PostgreSQL. +See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability +for a higher level overview. + ## Main process/worker communication The processes communicate with each other via a Synapse-specific protocol called diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 2f97e6d258..c7877b4095 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -40,7 +40,7 @@ class CasConfig(Config): self.cas_required_attributes = {} def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """ + return """\ # Enable Central Authentication Service (CAS) for registration and login. # cas_config: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index fddca19223..c7fa749377 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,7 +15,7 @@ # limitations under the License. import string -from typing import Optional, Type +from typing import Iterable, Optional, Type import attr @@ -33,16 +33,8 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - validate_config(MAIN_CONFIG_SCHEMA, config, ()) - - self.oidc_provider = None # type: Optional[OidcProviderConfig] - - oidc_config = config.get("oidc_config") - if oidc_config and oidc_config.get("enabled", False): - validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) - self.oidc_provider = _parse_oidc_config_dict(oidc_config) - - if not self.oidc_provider: + self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) + if not self.oidc_providers: return try: @@ -58,144 +50,153 @@ class OIDCConfig(Config): @property def oidc_enabled(self) -> bool: # OIDC is enabled if we have a provider - return bool(self.oidc_provider) + return bool(self.oidc_providers) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. + # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration + # and login. + # + # Options for each entry include: + # + # idp_id: a unique identifier for this identity provider. Used internally + # by Synapse; should be a single word such as 'github'. + # + # Note that, if this is changed, users authenticating via that provider + # will no longer be recognised as the same user! + # + # idp_name: A user-facing name for this identity provider, which is used to + # offer the user a choice of login mechanisms. + # + # discover: set to 'false' to disable the use of the OIDC discovery mechanism + # to discover endpoints. Defaults to true. + # + # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery + # is enabled) to discover the provider's endpoints. + # + # client_id: Required. oauth2 client id to use. + # + # client_secret: Required. oauth2 client secret to use. + # + # client_auth_method: auth method to use when exchanging the token. Valid + # values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + # scopes: list of scopes to request. This should normally include the "openid" + # scope. Defaults to ["openid"]. + # + # authorization_endpoint: the oauth2 authorization endpoint. Required if + # provider discovery is disabled. + # + # token_endpoint: the oauth2 token endpoint. Required if provider discovery is + # disabled. + # + # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is + # disabled and the 'openid' scope is not requested. + # + # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and + # the 'openid' scope is used. + # + # skip_verification: set to 'true' to skip metadata verification. Use this if + # you are connecting to a provider that is not OpenID Connect compliant. + # Defaults to false. Avoid this in production. + # + # user_profile_method: Whether to fetch the user profile from the userinfo + # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. + # + # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is + # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the + # userinfo endpoint. + # + # allow_existing_users: set to 'true' to allow a user logging in via OIDC to + # match a pre-existing account instead of failing. This could be used if + # switching from password logins to OIDC. Defaults to false. + # + # user_mapping_provider: Configuration for how attributes returned from a OIDC + # provider are mapped onto a matrix user. This setting has the following + # sub-properties: + # + # module: The class name of a custom mapping module. Default is + # {mapping_provider!r}. + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + # config: Configuration for the mapping provider module. This section will + # be passed as a Python dictionary to the user mapping provider + # module's `parse_config` method. + # + # For the default provider, the following settings are available: + # + # sub: name of the claim containing a unique identifier for the + # user. Defaults to 'sub', which OpenID Connect compliant + # providers should provide. + # + # localpart_template: Jinja2 template for the localpart of the MXID. + # If this is not set, the user will be prompted to choose their + # own username. + # + # display_name_template: Jinja2 template for the display name to set + # on first login. If unset, no displayname will be set. + # + # extra_attributes: a map of Jinja2 templates for extra attributes + # to send back to the client during login. + # Note that these are non-standard and clients will ignore them + # without modifications. + # + # When rendering, the Jinja2 templates are given a 'user' variable, + # which is set to the claims returned by the UserInfo Endpoint and/or + # in the ID Token. # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md - # for some example configurations. + # for information on how to configure these options. # - oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. + # For backwards compatibility, it is also possible to configure a single OIDC + # provider via an 'oidc_config' setting. This is now deprecated and admins are + # advised to migrate to the 'oidc_providers' format. + # + oidc_providers: + # Generic example # - #enabled: true + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. + # For use with Keycloak # - #discover: false + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. + # For use with Github # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is {mapping_provider!r}. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{{{ user.preferred_username }}}}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{{{ user.birthdate }}}}" + #- idp_id: google + # idp_name: Google + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) @@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { }, } -# the `oidc_config` setting can either be None (as it is in the default +# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name +OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = { + "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}] +} + + +# the `oidc_providers` list can either be None (as it is in the default config), or +# a list of provider configs, each of which requires an explicit ID and name. +OIDC_PROVIDER_LIST_SCHEMA = { + "oneOf": [ + {"type": "null"}, + {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA}, + ] +} + +# the `oidc_config` setting can either be None (which it used to be in the default # config), or an object. If an object, it is ignored unless it has an "enabled: True" # property. # @@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { # additional checks in the code. OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]} +# the top-level schema can contain an "oidc_config" and/or an "oidc_providers". MAIN_CONFIG_SCHEMA = { "type": "object", - "properties": {"oidc_config": OIDC_CONFIG_SCHEMA}, + "properties": { + "oidc_config": OIDC_CONFIG_SCHEMA, + "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA, + }, } +def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]: + """extract and parse the OIDC provider configs from the config dict + + The configuration may contain either a single `oidc_config` object with an + `enabled: True` property, or a list of provider configurations under + `oidc_providers`, *or both*. + + Returns a generator which yields the OidcProviderConfig objects + """ + validate_config(MAIN_CONFIG_SCHEMA, config, ()) + + for p in config.get("oidc_providers") or []: + yield _parse_oidc_config_dict(p) + + # for backwards-compatibility, it is also possible to provide a single "oidc_config" + # object with an "enabled: True" property. + oidc_config = config.get("oidc_config") + if oidc_config and oidc_config.get("enabled", False): + # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that + # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA + # above), so now we need to validate it. + validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) + yield _parse_oidc_config_dict(oidc_config) + + def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": """Take the configuration dict and parse it into an OidcProviderConfig diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc5f75123c..740c3fc1b1 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -14,14 +14,13 @@ # limitations under the License. import os -from distutils.util import strtobool import pkg_resources from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID -from synapse.util.stringutils import random_string_with_symbols +from synapse.util.stringutils import random_string_with_symbols, strtobool class AccountValidityConfig(Config): @@ -86,12 +85,12 @@ class RegistrationConfig(Config): section = "registration" def read_config(self, config, **kwargs): - self.enable_registration = bool( - strtobool(str(config.get("enable_registration", False))) + self.enable_registration = strtobool( + str(config.get("enable_registration", False)) ) if "disable_registration" in config: - self.enable_registration = not bool( - strtobool(str(config["disable_registration"])) + self.enable_registration = not strtobool( + str(config["disable_registration"]) ) self.account_validity = AccountValidityConfig( diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8028663fa8..3ec4120f85 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -17,7 +17,6 @@ import abc import os -from distutils.util import strtobool from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze +from synapse.util.stringutils import strtobool # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a @@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze # NOTE: This is overridden by the configuration by the Synapse worker apps, but # for the sake of tests, it is set here while it cannot be configured on the # homeserver object itself. + USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 109dc7932f..37a678b6ce 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -163,7 +163,7 @@ class DeviceMessageHandler: await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background - run_in_background(self._user_device_resync, sender_user_id) + run_in_background(self._user_device_resync, user_id=sender_user_id) async def send_device_message( self, diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f63a90ec5c..5e5fda7b2f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -78,21 +78,28 @@ class OidcHandler: def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() - provider_conf = hs.config.oidc.oidc_provider + provider_confs = hs.config.oidc.oidc_providers # we should not have been instantiated if there is no configured provider. - assert provider_conf is not None + assert provider_confs self._token_generator = OidcSessionTokenGenerator(hs) - - self._provider = OidcProvider(hs, self._token_generator, provider_conf) + self._providers = { + p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs + } async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. Called at startup to ensure we have everything we need. """ - await self._provider.load_metadata() - await self._provider.load_jwks() + for idp_id, p in self._providers.items(): + try: + await p.load_metadata() + await p.load_jwks() + except Exception as e: + raise Exception( + "Error while initialising OIDC provider %r" % (idp_id,) + ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: """Handle an incoming request to /_synapse/oidc/callback @@ -184,6 +191,12 @@ class OidcHandler: self._sso_handler.render_error(request, "mismatching_session", str(e)) return + oidc_provider = self._providers.get(session_data.idp_id) + if not oidc_provider: + logger.error("OIDC session uses unknown IdP %r", oidc_provider) + self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP") + return + if b"code" not in request.args: logger.info("Code parameter is missing") self._sso_handler.render_error( @@ -193,7 +206,7 @@ class OidcHandler: code = request.args[b"code"][0].decode() - await self._provider.handle_oidc_callback(request, session_data, code) + await oidc_provider.handle_oidc_callback(request, session_data, code) class OidcError(Exception): diff --git a/synapse/http/client.py b/synapse/http/client.py index dc4b81ca60..df498c8645 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): self.max_size = max_size def dataReceived(self, data: bytes) -> None: + # If the deferred was called, bail early. + if self.deferred.called: + return + self.stream.write(data) self.length += len(data) + # The first time the maximum size is exceeded, error and cancel the + # connection. dataReceived might be called again if data was received + # in the meantime. if self.max_size is not None and self.length >= self.max_size: self.deferred.errback(BodyExceededMaxSize()) - self.deferred = defer.Deferred() self.transport.loseConnection() def connectionLost(self, reason: Failure) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c82b4f87d6..8720b1401f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,6 +15,9 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -23,6 +26,10 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet): admin_patterns("/quarantine_media/(?P[^/]+)") ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id: str): + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id: str): + async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet): "/media/quarantine/(?P[^/]+)/(?P[^/]+)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, server_name: str, media_id: str): + async def on_POST( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class ProtectMediaByID(RestServlet): + """Protect local media from being quarantined. + """ + + PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Protecting local media by ID: %s", media_id) + + # Quarantine this media id + await self.store.mark_local_media_as_safe(media_id) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ PATTERNS = admin_patterns("/room/(?P[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 47c2b44bff..31a41e4a27 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ import logging import os import urllib -from typing import Awaitable +from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [ ] -def parse_media_id(request): +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. @@ -69,7 +70,7 @@ def parse_media_id(request): ) -def respond_404(request): +def respond_404(request: Request) -> None: respond_with_json( request, 404, @@ -79,8 +80,12 @@ def respond_404(request): async def respond_with_file( - request, media_type, file_path, file_size=None, upload_name=None -): + request: Request, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -98,15 +103,20 @@ async def respond_with_file( respond_404(request) -def add_file_headers(request, media_type, file_size, upload_name): +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: """Adds the correct response headers in preparation for responding with the media. Args: - request (twisted.web.http.Request) - media_type (str): The media/content type. - file_size (int): Size in bytes of the media, if known. - upload_name (str): The name of the requested file, if any. + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. """ def _quote(x): @@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name): # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - request.setHeader(b"Content-Length", b"%d" % (file_size,)) + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) # Tell web crawlers to not index, archive, or follow links in media. This # should help to prevent things in the media repo from showing up in web @@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = { } -def _can_encode_filename_as_token(x): +def _can_encode_filename_as_token(x: str) -> bool: for c in x: # from RFC2616: # @@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x): async def respond_with_responder( - request, responder, media_type, file_size, upload_name=None -): + request: Request, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: """Responds to the request with given responder. If responder is None then returns 404. Args: - request (twisted.web.http.Request) - responder (Responder|None) - media_type (str): The media/content type. - file_size (int|None): Size in bytes of the media. If not known it should be None - upload_name (str|None): The name of the requested file, if any. + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. """ if request._disconnected: logger.warning( @@ -308,22 +323,22 @@ class FileInfo: self.thumbnail_type = thumbnail_type -def get_filename_from_headers(headers): +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. Args: - headers (dict[bytes, list[bytes]]): The HTTP request headers. + headers: The HTTP request headers. Returns: - A Unicode string of the filename, or None. + The filename, or None. """ content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: - return + return None _, params = _parse_header(content_disposition[0]) @@ -356,17 +371,16 @@ def get_filename_from_headers(headers): return upload_name -def _parse_header(line): +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: """Parse a Content-type like header. Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - line (bytes): header to be parsed + line: header to be parsed Returns: - Tuple[bytes, dict[bytes, bytes]]: - the main content-type, followed by the parameter dictionary + The main content-type, followed by the parameter dictionary """ parts = _parseparam(b";" + line) key = next(parts) @@ -386,16 +400,16 @@ def _parse_header(line): return key, pdict -def _parseparam(s): +def _parseparam(s: bytes) -> Generator[bytes, None, None]: """Generator which splits the input on ;, respecting double-quoted sequences Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - s (bytes): header to be parsed + s: header to be parsed Returns: - Iterable[bytes]: the split input + The split input """ while s[:1] == b";": s = s[1:] diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 68dd2a1c8a..4e4c6971f7 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 Will Hunt +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,29 @@ # limitations under the License. # +from typing import TYPE_CHECKING + +from twisted.web.http import Request + from synapse.http.server import DirectServeJsonResource, respond_with_json +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class MediaConfigResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d3d8457303..3ed219ae43 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request -import synapse.http.servlet from synapse.http.server import DirectServeJsonResource, set_cors_headers +from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class DownloadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource): if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) else: - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True - ) + allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 9e079f672f..7792f26e78 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,12 @@ import functools import os import re +from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func): +def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -41,12 +43,18 @@ class MediaFilePaths: to write to the backup media store (when one is configured) """ - def __init__(self, primary_base_path): + def __init__(self, primary_base_path: str): self.base_path = primary_base_path def default_thumbnail_rel( - self, default_top_level, default_sub_type, width, height, content_type, method - ): + self, + default_top_level: str, + default_sub_type: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -55,12 +63,14 @@ class MediaFilePaths: default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id): + def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -86,7 +96,7 @@ class MediaFilePaths: media_id[4:], ) - def remote_media_filepath_rel(self, server_name, file_id): + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) @@ -94,8 +104,14 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) def remote_media_thumbnail_rel( - self, server_name, file_id, width, height, content_type, method - ): + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -113,7 +129,7 @@ class MediaFilePaths: # Should be removed after some time, when most of the thumbnails are stored # using the new path. def remote_media_thumbnail_rel_legacy( - self, server_name, file_id, width, height, content_type + self, server_name: str, file_id: str, width: int, height: int, content_type: str ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) @@ -126,7 +142,7 @@ class MediaFilePaths: file_name, ) - def remote_media_thumbnail_dir(self, server_name, file_id): + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", @@ -136,7 +152,7 @@ class MediaFilePaths: file_id[4:], ) - def url_cache_filepath_rel(self, media_id): + def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -146,7 +162,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - def url_cache_filepath_dirs_to_delete(self, media_id): + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [os.path.join(self.base_path, "url_cache", media_id[:10])] @@ -156,7 +172,9 @@ class MediaFilePaths: os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -178,7 +196,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id): + def url_cache_thumbnail_directory(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -195,7 +213,7 @@ class MediaFilePaths: media_id[4:], ) - def url_cache_thumbnail_dirs_to_delete(self, media_id): + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 83beb02b05..4c9946a616 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import errno import logging import os import shutil -from typing import IO, Dict, List, Optional, Tuple +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http @@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.client = hs.get_federation_http_client() @@ -73,16 +76,16 @@ class MediaRepository: self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.primary_base_path = hs.config.media_store_path - self.filepaths = MediaFilePaths(self.primary_base_path) + self.primary_base_path = hs.config.media_store_path # type: str + self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() - self.recently_accessed_locals = set() + self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] + self.recently_accessed_locals = set() # type: Set[str] self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -113,7 +116,7 @@ class MediaRepository: "update_recently_accessed_media", self._update_recently_accessed ) - async def _update_recently_accessed(self): + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -124,12 +127,12 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name, media_id): + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. Args: - server_name (str|None): Origin server of media, or None if local - media_id (str): The media ID of the content + server_name: Origin server of media, or None if local + media_id: The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) @@ -459,7 +462,14 @@ class MediaRepository: def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: m_width = thumbnailer.width m_height = thumbnailer.height @@ -470,22 +480,20 @@ class MediaRepository: m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = thumbnailer.transpose() if t_method == "crop": - t_byte_source = thumbnailer.crop(t_width, t_height, t_type) + return thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_byte_source = thumbnailer.scale(t_width, t_height, t_type) - else: - t_byte_source = None + return thumbnailer.scale(t_width, t_height, t_type) - return t_byte_source + return None async def generate_local_exact_thumbnail( self, @@ -776,7 +784,7 @@ class MediaRepository: return {"width": m_width, "height": m_height} - async def delete_old_remote_media(self, before_ts): + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource): within a given rectangle. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. if not hs.config.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 268e0c8f50..89cdd605aa 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import os import shutil from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable @@ -270,7 +272,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) -def _write_file_synchronously(source, dest): +def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. @@ -286,14 +288,14 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file (file): A file like object to be streamed ot the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ - def __init__(self, open_file): + def __init__(self, open_file: IO): self.open_file = open_file - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Deferred: return make_deferred_yieldable( FileSender().beginFileTransfer(self.open_file, consumer) ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1082389d9b..a632099167 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import errno import fnmatch @@ -23,12 +23,13 @@ import re import shutil import sys import traceback -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string from ._base import FileInfo +if TYPE_CHECKING: + from lxml import etree + + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) @@ -119,7 +127,12 @@ class OEmbedError(Exception): class PreviewUrlResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.auth = hs.get_auth() @@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) @@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url: str, user): + async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource): "expire_url_cache_data", self._expire_url_cache_data ) - async def _expire_url_cache_data(self): + async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: +def decode_and_calc_og( + body: bytes, media_uri: str, request_encoding: Optional[str] = None +) -> Dict[str, Optional[str]]: # If there's no body, nothing useful is going to be found. if not body: return {} @@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str] return og -def _calc_og(tree, media_uri): +def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them @@ -801,7 +816,9 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og["og:description"] = summarize_paragraphs(text_nodes) - else: + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, @@ -809,7 +826,9 @@ def _calc_og(tree, media_uri): return og -def _iterate_over_text(tree, *tags_to_ignore): +def _iterate_over_text( + tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. """ @@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore): ) -def _rebase_url(url, base): - base = list(urlparse.urlparse(base)) - url = list(urlparse.urlparse(url)) - if not url[0]: # fix up schema - url[0] = base[0] or "http" - if not url[1]: # fix up hostname - url[1] = base[1] - if not url[2].startswith("/"): - url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] - return urlparse.urlunparse(url) +def _rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) -def _is_media(content_type): - if content_type.lower().startswith("image/"): - return True +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") -def _is_html(content_type): +def _is_html(content_type: str) -> bool: content_type = content_type.lower() - if content_type.startswith("text/html") or content_type.startswith( + return content_type.startswith("text/html") or content_type.startswith( "application/xhtml" - ): - return True + ) -def summarize_paragraphs(text_nodes, min_size=200, max_size=500): +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: # Try to get a summary of between 200 and 500 words, respecting # first paragraph and then word boundaries. # TODO: Respect sentences? diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 67f67efde7..e92006faa9 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -27,13 +28,17 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -class StorageProvider: + +class StorageProvider(metaclass=abc.ABCMeta): """A storage provider is a service that can store uploaded media and retrieve them. """ - async def store_file(self, path: str, file_info: FileInfo): + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -42,6 +47,7 @@ class StorageProvider: file_info: The metadata of the file. """ + @abc.abstractmethod async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. @@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote - def __str__(self): + def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - return await maybe_awaitable(self.backend.store_file(path, file_info)) + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. async def store(): @@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return None - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) @@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider): """A storage provider that stores files in a directory on a filesystem. Args: - hs (HomeServer) + hs config: The config returned by `parse_config`. """ - def __init__(self, hs, config): + def __init__(self, hs: "HomeServer", config: str): self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config @@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return await defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): return FileResponder(open(backup_fname, "rb")) + return None + @staticmethod - def parse_config(config): + def parse_config(config: dict) -> str: """Called on startup to parse config supplied. This should parse the config and raise if there is a problem. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 30421b663a..d6880f2e6e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,14 @@ import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( FileInfo, @@ -28,13 +33,22 @@ from ._base import ( respond_with_responder, ) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class ThumbnailResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.store = hs.get_datastore() @@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( - self, request, media_id, width, height, method, m_type - ): + self, + request: Request, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_local_thumbnail( self, - request, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request, - server_name, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + server_name: str, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = await self.store.get_remote_media_thumbnails( @@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource): raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( - self, request, server_name, media_id, width, height, method, m_type - ): + self, + request: Request, + server_name: str, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. @@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource): def _select_thumbnail( self, - desired_width, - desired_height, - desired_method, - desired_type, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, thumbnail_infos, - ): + ) -> dict: d_w = desired_width d_h = desired_height diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 32a8e4f960..07903e4017 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. import logging from io import BytesIO +from typing import Tuple from PIL import Image @@ -39,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - def __init__(self, input_path): + def __init__(self, input_path: str): try: self.image = Image.open(input_path) except OSError as e: @@ -59,11 +61,11 @@ class Thumbnailer: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) - def transpose(self): + def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag Returns: - Tuple[int, int]: (width, height) containing the new image size in pixels. + A tuple containing the new image size in pixels as (width, height). """ if self.transpose_method is not None: self.image = self.image.transpose(self.transpose_method) @@ -73,7 +75,7 @@ class Thumbnailer: self.image.info["exif"] = None return self.image.size - def aspect(self, max_width, max_height): + def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: """Calculate the largest size that preserves aspect ratio which fits within the given rectangle:: @@ -91,7 +93,7 @@ class Thumbnailer: else: return (max_height * self.width) // self.height, max_height - def _resize(self, width, height): + def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which # looks awful @@ -99,7 +101,7 @@ class Thumbnailer: self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) - def scale(self, width, height, output_type): + def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: @@ -108,7 +110,7 @@ class Thumbnailer: scaled = self._resize(width, height) return self._encode_image(scaled, output_type) - def crop(self, width, height, output_type): + def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -136,7 +138,7 @@ class Thumbnailer: cropped = scaled_image.crop((crop_left, 0, crop_right, height)) return self._encode_image(cropped, output_type) - def _encode_image(self, output_image, output_type): + def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] if fmt == "JPEG": diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 42febc9afc..6da76ae994 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo @@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 7128dc1742..e46e44ba54 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -16,6 +16,8 @@ import logging from typing import Dict, List, Optional, Tuple +import attr + from synapse.api.constants import EventContentFields from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -28,6 +30,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True) +class _CalculateChainCover: + """Return value for _calculate_chain_cover_txn. + """ + + # The last room_id/depth/stream processed. + room_id = attr.ib(type=str) + depth = attr.ib(type=int) + stream = attr.ib(type=int) + + # Number of rows processed + processed_count = attr.ib(type=int) + + # Map from room_id to last depth/stream processed for each room that we have + # processed all events for (i.e. the rooms we can flip the + # `has_auth_chain_index` for) + finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + + class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): current_room_id = progress.get("current_room_id", "") - # Have we finished processing the current room. - finished = progress.get("finished", True) - # Where we've processed up to in the room, defaults to the start of the # room. last_depth = progress.get("last_depth", -1) last_stream = progress.get("last_stream", -1) - # Have we set the `has_auth_chain_index` for the room yet. - has_set_room_has_chain_index = progress.get( - "has_set_room_has_chain_index", False + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + current_room_id, + last_depth, + last_stream, + batch_size, + single_room=False, ) - if finished: - # If we've finished with the previous room (or its our first - # iteration) we move on to the next room. + finished = result.processed_count == 0 - def _get_next_room(txn: Cursor) -> Optional[str]: - sql = """ - SELECT room_id FROM rooms - WHERE room_id > ? - AND ( - NOT has_auth_chain_index - OR has_auth_chain_index IS NULL - ) - ORDER BY room_id - LIMIT 1 - """ - txn.execute(sql, (current_room_id,)) - row = txn.fetchone() - if row: - return row[0] + total_rows_processed = result.processed_count + current_room_id = result.room_id + last_depth = result.depth + last_stream = result.stream - return None - - current_room_id = await self.db_pool.runInteraction( - "_chain_cover_index", _get_next_room - ) - if not current_room_id: - await self.db_pool.updates._end_background_update("chain_cover") - return 0 - - logger.debug("Adding chain cover to %s", current_room_id) - - def _calculate_auth_chain( - txn: Cursor, last_depth: int, last_stream: int - ) -> Tuple[int, int, int]: - # Get the next set of events in the room (that we haven't already - # computed chain cover for). We do this in topological order. - - # We want to do a `(topological_ordering, stream_ordering) > (?,?)` - # comparison, but that is not supported on older SQLite versions - tuple_clause, tuple_args = make_tuple_comparison_clause( - self.database_engine, - [ - ("topological_ordering", last_depth), - ("stream_ordering", last_stream), - ], - ) - - sql = """ - SELECT - event_id, state_events.type, state_events.state_key, - topological_ordering, stream_ordering - FROM events - INNER JOIN state_events USING (event_id) - LEFT JOIN event_auth_chains USING (event_id) - LEFT JOIN event_auth_chain_to_calculate USING (event_id) - WHERE events.room_id = ? - AND event_auth_chains.event_id IS NULL - AND event_auth_chain_to_calculate.event_id IS NULL - AND %(tuple_cmp)s - ORDER BY topological_ordering, stream_ordering - LIMIT ? - """ % { - "tuple_cmp": tuple_clause, - } - - args = [current_room_id] - args.extend(tuple_args) - args.append(batch_size) - - txn.execute(sql, args) - rows = txn.fetchall() - - # Put the results in the necessary format for - # `_add_chain_cover_index` - event_to_room_id = {row[0]: current_room_id for row in rows} - event_to_types = {row[0]: (row[1], row[2]) for row in rows} - - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - - count = len(rows) - - # We also need to fetch the auth events for them. - auth_events = self.db_pool.simple_select_many_txn( - txn, - table="event_auth", - column="event_id", - iterable=event_to_room_id, - keyvalues={}, - retcols=("event_id", "auth_id"), - ) - - event_to_auth_chain = {} # type: Dict[str, List[str]] - for row in auth_events: - event_to_auth_chain.setdefault(row["event_id"], []).append( - row["auth_id"] - ) - - # Calculate and persist the chain cover index for this set of events. - # - # Annoyingly we need to gut wrench into the persit event store so that - # we can reuse the function to calculate the chain cover for rooms. - PersistEventsStore._add_chain_cover_index( - txn, - self.db_pool, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) - - return new_last_depth, new_last_stream, count - - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream - ) - - total_rows_processed = count - - if count < batch_size and not has_set_room_has_chain_index: + for room_id, (depth, stream) in result.finished_room_map.items(): # If we've done all the events in the room we flip the # `has_auth_chain_index` in the DB. Note that its possible for # further events to be persisted between the above and setting the @@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.simple_update( table="rooms", - keyvalues={"room_id": current_room_id}, + keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": True}, desc="_chain_cover_index", ) - has_set_room_has_chain_index = True # Handle any events that might have raced with us flipping the # bit above. - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + room_id, + depth, + stream, + batch_size=None, + single_room=True, ) - total_rows_processed += count + total_rows_processed += result.processed_count - # Note that at this point its technically possible that more events - # than our `batch_size` have been persisted without their chain - # cover, so we need to continue processing this room if the last - # count returned was equal to the `batch_size`. + if finished: + await self.db_pool.updates._end_background_update("chain_cover") + return total_rows_processed - if count < batch_size: - # We've finished calculating the index for this room, move on to the - # next room. - await self.db_pool.updates._background_update_progress( - "chain_cover", {"current_room_id": current_room_id, "finished": True}, - ) - else: - # We still have outstanding events to calculate the index for. - await self.db_pool.updates._background_update_progress( - "chain_cover", - { - "current_room_id": current_room_id, - "last_depth": last_depth, - "last_stream": last_stream, - "has_auth_chain_index": has_set_room_has_chain_index, - "finished": False, - }, - ) + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + }, + ) return total_rows_processed + + def _calculate_chain_cover_txn( + self, + txn: Cursor, + last_room_id: str, + last_depth: int, + last_stream: int, + batch_size: Optional[int], + single_room: bool, + ) -> _CalculateChainCover: + """Calculate the chain cover for `batch_size` events, ordered by + `(room_id, depth, stream)`. + + Args: + txn, + last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` + tuple to fetch results after. + batch_size: The maximum number of events to process. If None then + no limit. + single_room: Whether to calculate the index for just the given + room. + """ + + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. + + # We want to do a `(topological_ordering, stream_ordering) > (?,?)` + # comparison, but that is not supported on older SQLite versions + tuple_clause, tuple_args = make_tuple_comparison_clause( + self.database_engine, + [ + ("events.room_id", last_room_id), + ("topological_ordering", last_depth), + ("stream_ordering", last_stream), + ], + ) + + extra_clause = "" + if single_room: + extra_clause = "AND events.room_id = ?" + tuple_args.append(last_room_id) + + sql = """ + SELECT + event_id, state_events.type, state_events.state_key, + topological_ordering, stream_ordering, + events.room_id + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + LEFT JOIN event_auth_chain_to_calculate USING (event_id) + WHERE event_auth_chains.event_id IS NULL + AND event_auth_chain_to_calculate.event_id IS NULL + AND %(tuple_cmp)s + %(extra)s + ORDER BY events.room_id, topological_ordering, stream_ordering + %(limit)s + """ % { + "tuple_cmp": tuple_clause, + "limit": "LIMIT ?" if batch_size is not None else "", + "extra": extra_clause, + } + + if batch_size is not None: + tuple_args.append(batch_size) + + txn.execute(sql, tuple_args) + rows = txn.fetchall() + + # Put the results in the necessary format for + # `_add_chain_cover_index` + event_to_room_id = {row[0]: row[5] for row in rows} + event_to_types = {row[0]: (row[1], row[2]) for row in rows} + + # Calculate the new last position we've processed up to. + new_last_depth = rows[-1][3] if rows else last_depth # type: int + new_last_stream = rows[-1][4] if rows else last_stream # type: int + new_last_room_id = rows[-1][5] if rows else "" # type: str + + # Map from room_id to last depth/stream_ordering processed for the room, + # excluding the last room (which we're likely still processing). We also + # need to include the room passed in if it's not included in the result + # set (as we then know we've processed all events in said room). + # + # This is the set of rooms that we can now safely flip the + # `has_auth_chain_index` bit for. + finished_rooms = { + row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id + } + if last_room_id not in finished_rooms and last_room_id != new_last_room_id: + finished_rooms[last_room_id] = (last_depth, last_stream) + + count = len(rows) + + # We also need to fetch the auth events for them. + auth_events = self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ) + + event_to_auth_chain = {} # type: Dict[str, List[str]] + for row in auth_events: + event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) + + # Calculate and persist the chain cover index for this set of events. + # + # Annoyingly we need to gut wrench into the persit event store so that + # we can reuse the function to calculate the chain cover for rooms. + PersistEventsStore._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + return _CalculateChainCover( + room_id=new_last_room_id, + depth=new_last_depth, + stream=new_last_stream, + processed_count=count, + finished_room_map=finished_rooms, + ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4b2f224718..283c8a5e22 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_local_media_before( self, before_ts: int, size_gt: int, keep_profiles: bool, - ) -> Optional[List[str]]: + ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) # compare with `created_ts` diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 77ba9d819e..bc7621b8d6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -17,14 +17,13 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple -from canonicaljson import encode_canonical_json - from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.types import Connection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore): "device_display_name": device_display_name, "ts": pushkey_ts, "lang": lang, - "data": bytearray(encode_canonical_json(data)), + "data": json_encoder.encode(data), "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 61d96a6c28..b103c8694c 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: if len(items) <= maxitems: return str(items) return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + + This is lifted from distutils.util.strtobool, with the exception that it actually + returns a bool, rather than an int. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r" % (val,)) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 02e21ed6ca..b3dfa40d25 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() - self.provider = self.handler._provider + self.provider = self.handler._providers["oidc"] sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) @@ -866,7 +866,7 @@ async def _make_callback_with_userinfo( from synapse.handlers.oidc_handler import OidcSessionData handler = hs.get_oidc_handler() - provider = handler._provider + provider = handler._providers["oidc"] provider._exchange_code = simple_async_mock(return_value={}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4e51839d0f..686012dd25 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # Expire both caches and repeat the request self.reactor.pump((10000.0,)) - # Repated the request, this time it should fail if the lookup fails. + # Repeat the request, this time it should fail if the lookup fails. fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv") ) @@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase): content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', ) - # The result is sucessful, but disabled delegation. + # The result is successful, but disabled delegation. r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) diff --git a/tests/http/test_client.py b/tests/http/test_client.py new file mode 100644 index 0000000000..f17c122e93 --- /dev/null +++ b/tests/http/test_client.py @@ -0,0 +1,101 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BytesIO + +from mock import Mock + +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone + +from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size + +from tests.unittest import TestCase + + +class ReadBodyWithMaxSizeTests(TestCase): + def setUp(self): + """Start reading the body, returns the response, result and proto""" + self.response = Mock() + self.result = BytesIO() + self.deferred = read_body_with_max_size(self.response, self.result, 6) + + # Fish the protocol out of the response. + self.protocol = self.response.deliverBody.call_args[0][0] + self.protocol.transport = Mock() + + def _cleanup_error(self): + """Ensure that the error in the Deferred is handled gracefully.""" + called = [False] + + def errback(f): + called[0] = True + + self.deferred.addErrback(errback) + self.assertTrue(called[0]) + + def test_no_error(self): + """A response that is NOT too large.""" + + # Start sending data. + self.protocol.dataReceived(b"12345") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"12345") + self.assertEqual(self.deferred.result, 5) + + def test_too_large(self): + """A response which is too large raises an exception.""" + + # Start sending data. + self.protocol.dataReceived(b"1234567890") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self._cleanup_error() + + def test_multiple_packets(self): + """Data should be accummulated through mutliple packets.""" + + # Start sending data. + self.protocol.dataReceived(b"12") + self.protocol.dataReceived(b"34") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234") + self.assertEqual(self.deferred.result, 4) + + def test_additional_data(self): + """A connection can receive data after being closed.""" + + # Start sending data. + self.protocol.dataReceived(b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self.protocol.transport.loseConnection.assert_called_once() + + # More data might have come in. + self.protocol.dataReceived(b"1234567890") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self._cleanup_error() diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 586b877bda..9d22c04073 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] @@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Mark the second item as safe from quarantine. _, media_id_2 = server_and_media_id_2.split("/") - self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + # Quarantine the media + url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) + channel = self.make_request("POST", url, access_token=admin_user_tok) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 3e8661f9b9..a6488a3d29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id) + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) # that should return a failure message self.assertSubstring("We were unable to validate", channel.text_body) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index ff67a73749..0c46ad595b 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from twisted.trial import unittest @@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def test_background_update(self): - """Test that the background update to calculate auth chains for historic - rooms works correctly. + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.requester = create_requester(self.user_id) + + def _generate_room(self) -> Tuple[str, List[Set[str]]]: + """Insert a room without a chain cover index. """ - - # Create a room - user_id = self.register_user("foo", "pass") - token = self.login("foo", "pass") - room_id = self.helper.create_room_as(user_id, tok=token) - requester = create_requester(user_id) - - store = self.hs.get_datastore() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Mark the room as not having a chain cover index self.get_success( - store.db_pool.simple_update( + self.store.db_pool.simple_update( table="rooms", keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": False}, @@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Create a fork in the DAG with different events. event_handler = self.hs.get_event_creation_handler() - latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + latest_event_ids = self.get_success( + self.store.get_prev_events_for_room(room_id) + ) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state1 = list(self.get_success(context.get_current_state_ids()).values()) + state1 = set(self.get_success(context.get_current_state_ids()).values()) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state2 = list(self.get_success(context.get_current_state_ids()).values()) + state2 = set(self.get_success(context.get_current_state_ids()).values()) # Delete the chain cover info. @@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chain_links") - self.get_success(store.db_pool.runInteraction("test", _delete_tables)) + self.get_success(self.store.db_pool.runInteraction("test", _delete_tables)) + + return room_id, [state1, state2] + + def test_background_update_single_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() # Insert and run the background update. self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "chain_cover", "progress_json": "{}"}, ) ) # Ugh, have to reset this flag - store.db_pool.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - store.db_pool.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - store.db_pool.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Test that the `has_auth_chain_index` has been set - self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) # Test that calculating the auth chain difference using the newly # calculated chain cover works. self.get_success( - store.db_pool.runInteraction( + self.store.db_pool.runInteraction( "test", - store._get_auth_chain_difference_using_cover_index_txn, + self.store._get_auth_chain_difference_using_cover_index_txn, room_id, - [state1, state2], + states, ) ) + + def test_background_update_multiple_rooms(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + # Create a room + room_id1, states1 = self._generate_room() + room_id2, states2 = self._generate_room() + room_id3, states2 = self._generate_room() + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id1, + states1, + ) + ) + + def test_background_update_single_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create the rooms + room_id1, _ = self._generate_room() + room_id2, _ = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id1, event_type="m.test", body={"index": i}, tok=self.token + ) + + for i in range(0, 150): + self.helper.send_state( + room_id2, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))