mirror of
https://github.com/element-hq/synapse
synced 2024-10-01 21:32:40 +00:00
Merge branch 'release-v1.24.0' of github.com:matrix-org/synapse into matrix-org-hotfixes
This commit is contained in:
commit
16744644f6
59 changed files with 1604 additions and 383 deletions
|
@ -6,7 +6,7 @@
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
|
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
|
||||||
|
|
||||||
export LANG="C.UTF-8"
|
export LANG="C.UTF-8"
|
||||||
|
|
||||||
|
|
1
changelog.d/8565.misc
Normal file
1
changelog.d/8565.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Simplify the way the `HomeServer` object caches its internal attributes.
|
1
changelog.d/8799.bugfix
Normal file
1
changelog.d/8799.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow per-room profiles to be used for the server notice user.
|
1
changelog.d/8800.misc
Normal file
1
changelog.d/8800.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add additional error checking for OpenID Connect and SAML mapping providers.
|
1
changelog.d/8804.feature
Normal file
1
changelog.d/8804.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow Date header through CORS. Contributed by Nicolas Chamo.
|
1
changelog.d/8809.misc
Normal file
1
changelog.d/8809.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Remove unnecessary function arguments and add typing to several membership replication classes.
|
1
changelog.d/8819.misc
Normal file
1
changelog.d/8819.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add tests for `password_auth_provider`s.
|
1
changelog.d/8820.feature
Normal file
1
changelog.d/8820.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages".
|
1
changelog.d/8833.removal
Normal file
1
changelog.d/8833.removal
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir.
|
1
changelog.d/8835.bugfix
Normal file
1
changelog.d/8835.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled.
|
1
changelog.d/8843.feature
Normal file
1
changelog.d/8843.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add `force_purge` option to delete-room admin api.
|
1
changelog.d/8845.misc
Normal file
1
changelog.d/8845.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Drop redundant database index on `event_json`.
|
1
changelog.d/8847.misc
Normal file
1
changelog.d/8847.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Simplify `uk.half-shot.msc2778.login.application_service` login handler.
|
1
changelog.d/8848.bugfix
Normal file
1
changelog.d/8848.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication.
|
1
changelog.d/8849.misc
Normal file
1
changelog.d/8849.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor `password_auth_provider` support code.
|
1
changelog.d/8850.misc
Normal file
1
changelog.d/8850.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add missing `ordering` to background database updates.
|
1
changelog.d/8851.misc
Normal file
1
changelog.d/8851.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Simplify the way the `HomeServer` object caches its internal attributes.
|
1
changelog.d/8854.misc
Normal file
1
changelog.d/8854.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.
|
1
changelog.d/8855.feature
Normal file
1
changelog.d/8855.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
|
|
@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected.
|
||||||
|
|
||||||
The API is:
|
The API is:
|
||||||
|
|
||||||
```json
|
```
|
||||||
POST /_synapse/admin/v1/rooms/<room_id>/delete
|
POST /_synapse/admin/v1/rooms/<room_id>/delete
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -439,6 +439,10 @@ The following JSON body parameters are available:
|
||||||
future attempts to join the room. Defaults to `false`.
|
future attempts to join the room. Defaults to `false`.
|
||||||
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
|
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
|
||||||
Defaults to `true`.
|
Defaults to `true`.
|
||||||
|
* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it
|
||||||
|
will force a purge to go ahead even if there are local users still in the room. Do not
|
||||||
|
use this unless a regular `purge` operation fails, as it could leave those users'
|
||||||
|
clients in a confused state.
|
||||||
|
|
||||||
The JSON body must not be empty. The body must be at least `{}`.
|
The JSON body must not be empty. The body must be at least `{}`.
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
|
||||||
|
|
||||||
It should perform any appropriate sanity checks on the provided
|
It should perform any appropriate sanity checks on the provided
|
||||||
configuration, and return an object which is then passed into
|
configuration, and return an object which is then passed into
|
||||||
|
`__init__`.
|
||||||
|
|
||||||
This method should have the `@staticmethod` decoration.
|
This method should have the `@staticmethod` decoration.
|
||||||
|
|
||||||
|
|
|
@ -2271,6 +2271,16 @@ push:
|
||||||
#
|
#
|
||||||
#include_content: false
|
#include_content: false
|
||||||
|
|
||||||
|
# When a push notification is received, an unread count is also sent.
|
||||||
|
# This number can either be calculated as the number of unread messages
|
||||||
|
# for the user, or the number of *rooms* the user has unread messages in.
|
||||||
|
#
|
||||||
|
# The default value is "true", meaning push clients will see the number of
|
||||||
|
# rooms with unread messages in them. Uncomment to instead send the number
|
||||||
|
# of unread messages.
|
||||||
|
#
|
||||||
|
#group_unread_count_by_room: false
|
||||||
|
|
||||||
|
|
||||||
# Spam checkers are third-party modules that can block specific actions
|
# Spam checkers are third-party modules that can block specific actions
|
||||||
# of local users, such as creating rooms and registering undesirable
|
# of local users, such as creating rooms and registering undesirable
|
||||||
|
|
1
mypy.ini
1
mypy.ini
|
@ -80,6 +80,7 @@ files =
|
||||||
synapse/util/metrics.py,
|
synapse/util/metrics.py,
|
||||||
tests/replication,
|
tests/replication,
|
||||||
tests/test_utils,
|
tests/test_utils,
|
||||||
|
tests/handlers/test_password_providers.py,
|
||||||
tests/rest/client/v2_alpha/test_auth.py,
|
tests/rest/client/v2_alpha/test_auth.py,
|
||||||
tests/util/test_stream_change_cache.py
|
tests/util/test_stream_change_cache.py
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,9 @@ class PushConfig(Config):
|
||||||
def read_config(self, config, **kwargs):
|
def read_config(self, config, **kwargs):
|
||||||
push_config = config.get("push") or {}
|
push_config = config.get("push") or {}
|
||||||
self.push_include_content = push_config.get("include_content", True)
|
self.push_include_content = push_config.get("include_content", True)
|
||||||
|
self.push_group_unread_count_by_room = push_config.get(
|
||||||
|
"group_unread_count_by_room", True
|
||||||
|
)
|
||||||
|
|
||||||
pusher_instances = config.get("pusher_instances") or []
|
pusher_instances = config.get("pusher_instances") or []
|
||||||
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
|
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
|
||||||
|
@ -68,4 +71,14 @@ class PushConfig(Config):
|
||||||
# include the event ID and room ID in push notification payloads.
|
# include the event ID and room ID in push notification payloads.
|
||||||
#
|
#
|
||||||
#include_content: false
|
#include_content: false
|
||||||
|
|
||||||
|
# When a push notification is received, an unread count is also sent.
|
||||||
|
# This number can either be calculated as the number of unread messages
|
||||||
|
# for the user, or the number of *rooms* the user has unread messages in.
|
||||||
|
#
|
||||||
|
# The default value is "true", meaning push clients will see the number of
|
||||||
|
# rooms with unread messages in them. Uncomment to instead send the number
|
||||||
|
# of unread messages.
|
||||||
|
#
|
||||||
|
#group_unread_count_by_room: false
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -25,6 +26,7 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
|
||||||
# better way to break the loop
|
# better way to break the loop
|
||||||
account_handler = ModuleApi(hs, self)
|
account_handler = ModuleApi(hs, self)
|
||||||
|
|
||||||
self.password_providers = []
|
self.password_providers = [
|
||||||
for module, config in hs.config.password_providers:
|
PasswordProvider.load(module, config, account_handler)
|
||||||
try:
|
for module, config in hs.config.password_providers
|
||||||
self.password_providers.append(
|
]
|
||||||
module(config=config, account_handler=account_handler)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error while initializing %r: %s", module, e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info("Extra password_providers: %r", self.password_providers)
|
logger.info("Extra password_providers: %s", self.password_providers)
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
|
||||||
# type in the list. (NB that the spec doesn't require us to do so and
|
# type in the list. (NB that the spec doesn't require us to do so and
|
||||||
# clients which favour types that they don't understand over those that
|
# clients which favour types that they don't understand over those that
|
||||||
# they do are technically broken)
|
# they do are technically broken)
|
||||||
|
|
||||||
|
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||||
login_types = []
|
login_types = []
|
||||||
if self._password_enabled:
|
if hs.config.password_localdb_enabled:
|
||||||
login_types.append(LoginType.PASSWORD)
|
login_types.append(LoginType.PASSWORD)
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "get_supported_login_types"):
|
if hasattr(provider, "get_supported_login_types"):
|
||||||
for t in provider.get_supported_login_types().keys():
|
for t in provider.get_supported_login_types().keys():
|
||||||
if t not in login_types:
|
if t not in login_types:
|
||||||
login_types.append(t)
|
login_types.append(t)
|
||||||
|
|
||||||
|
if not self._password_enabled:
|
||||||
|
login_types.remove(LoginType.PASSWORD)
|
||||||
|
|
||||||
self._supported_login_types = login_types
|
self._supported_login_types = login_types
|
||||||
|
|
||||||
# Login types and UI Auth types have a heavy overlap, but are not
|
# Login types and UI Auth types have a heavy overlap, but are not
|
||||||
# necessarily identical. Login types have SSO (and other login types)
|
# necessarily identical. Login types have SSO (and other login types)
|
||||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
||||||
|
@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ratelimitier for failed /login attempts
|
||||||
|
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
||||||
|
clock=hs.get_clock(),
|
||||||
|
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||||
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
|
)
|
||||||
|
|
||||||
self._clock = self.hs.get_clock()
|
self._clock = self.hs.get_clock()
|
||||||
|
|
||||||
# Expire old UI auth sessions after a period of time.
|
# Expire old UI auth sessions after a period of time.
|
||||||
|
@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
|
||||||
res = await checker.check_auth(authdict, clientip=clientip)
|
res = await checker.check_auth(authdict, clientip=clientip)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# build a v1-login-style dict out of the authdict and fall back to the
|
# fall back to the v1 login flow
|
||||||
# v1 code
|
canonical_id, _ = await self.validate_login(authdict)
|
||||||
user_id = authdict.get("user")
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
|
||||||
|
|
||||||
(canonical_id, callback) = await self.validate_login(user_id, authdict)
|
|
||||||
return canonical_id
|
return canonical_id
|
||||||
|
|
||||||
def _get_params_recaptcha(self) -> dict:
|
def _get_params_recaptcha(self) -> dict:
|
||||||
|
@ -824,15 +830,157 @@ class AuthHandler(BaseHandler):
|
||||||
return self._supported_login_types
|
return self._supported_login_types
|
||||||
|
|
||||||
async def validate_login(
|
async def validate_login(
|
||||||
self, username: str, login_submission: Dict[str, Any]
|
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||||
"""Authenticates the user for the /login API
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
Also used by the user-interactive auth flow to validate
|
Also used by the user-interactive auth flow to validate auth types which don't
|
||||||
m.login.password auth types.
|
have an explicit UIA handler, including m.password.auth.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
username: username supplied by the user
|
login_submission: the whole of the login submission
|
||||||
|
(including 'type' and other relevant fields)
|
||||||
|
ratelimit: whether to apply the failed_login_attempt ratelimiter
|
||||||
|
Returns:
|
||||||
|
A tuple of the canonical user id, and optional callback
|
||||||
|
to be called once the access token and device id are issued
|
||||||
|
Raises:
|
||||||
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
|
"""
|
||||||
|
login_type = login_submission.get("type")
|
||||||
|
if not isinstance(login_type, str):
|
||||||
|
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
||||||
|
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
||||||
|
#
|
||||||
|
# But the auth providers' check_auth interface requires a username, so in
|
||||||
|
# practice we can only support login methods which we can map to a username
|
||||||
|
# anyway.
|
||||||
|
|
||||||
|
# special case to check for "password" for the check_password interface
|
||||||
|
# for the auth providers
|
||||||
|
password = login_submission.get("password")
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
if not self._password_enabled:
|
||||||
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
if not isinstance(password, str):
|
||||||
|
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
# map old-school login fields into new-school "identifier" fields.
|
||||||
|
identifier_dict = convert_client_dict_legacy_fields_to_identifier(
|
||||||
|
login_submission
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert phone type identifiers to generic threepids
|
||||||
|
if identifier_dict["type"] == "m.id.phone":
|
||||||
|
identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
|
||||||
|
|
||||||
|
# convert threepid identifiers to user IDs
|
||||||
|
if identifier_dict["type"] == "m.id.thirdparty":
|
||||||
|
address = identifier_dict.get("address")
|
||||||
|
medium = identifier_dict.get("medium")
|
||||||
|
|
||||||
|
if medium is None or address is None:
|
||||||
|
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||||
|
|
||||||
|
# For emails, canonicalise the address.
|
||||||
|
# We store all email addresses canonicalised in the DB.
|
||||||
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
|
if medium == "email":
|
||||||
|
try:
|
||||||
|
address = canonicalise_email(address)
|
||||||
|
except ValueError as e:
|
||||||
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
|
# We also apply account rate limiting using the 3PID as a key, as
|
||||||
|
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||||
|
if ratelimit:
|
||||||
|
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
|
(medium, address), update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for login providers that support 3pid login types
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
# we've already checked that there is a (valid) password field
|
||||||
|
assert isinstance(password, str)
|
||||||
|
(
|
||||||
|
canonical_user_id,
|
||||||
|
callback_3pid,
|
||||||
|
) = await self.check_password_provider_3pid(medium, address, password)
|
||||||
|
if canonical_user_id:
|
||||||
|
# Authentication through password provider and 3pid succeeded
|
||||||
|
return canonical_user_id, callback_3pid
|
||||||
|
|
||||||
|
# No password providers were able to handle this 3pid
|
||||||
|
# Check local store
|
||||||
|
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
medium, address
|
||||||
|
)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning(
|
||||||
|
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||||
|
)
|
||||||
|
# We mark that we've failed to log in here, as
|
||||||
|
# `check_password_provider_3pid` might have returned `None` due
|
||||||
|
# to an incorrect password, rather than the account not
|
||||||
|
# existing.
|
||||||
|
#
|
||||||
|
# If it returned None but the 3PID was bound then we won't hit
|
||||||
|
# this code path, which is fine as then the per-user ratelimit
|
||||||
|
# will kick in below.
|
||||||
|
if ratelimit:
|
||||||
|
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
|
(medium, address)
|
||||||
|
)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
identifier_dict = {"type": "m.id.user", "user": user_id}
|
||||||
|
|
||||||
|
# by this point, the identifier should be an m.id.user: if it's anything
|
||||||
|
# else, we haven't understood it.
|
||||||
|
if identifier_dict["type"] != "m.id.user":
|
||||||
|
raise SynapseError(400, "Unknown login identifier type")
|
||||||
|
|
||||||
|
username = identifier_dict.get("user")
|
||||||
|
if not username:
|
||||||
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
|
if username.startswith("@"):
|
||||||
|
qualified_user_id = username
|
||||||
|
else:
|
||||||
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
|
# Check if we've hit the failed ratelimit (but don't update it)
|
||||||
|
if ratelimit:
|
||||||
|
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
|
qualified_user_id.lower(), update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._validate_userid_login(username, login_submission)
|
||||||
|
except LoginError:
|
||||||
|
# The user has failed to log in, so we need to update the rate
|
||||||
|
# limiter. Using `can_do_action` avoids us raising a ratelimit
|
||||||
|
# exception and masking the LoginError. The actual ratelimiting
|
||||||
|
# should have happened above.
|
||||||
|
if ratelimit:
|
||||||
|
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
|
qualified_user_id.lower()
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _validate_userid_login(
|
||||||
|
self, username: str, login_submission: Dict[str, Any],
|
||||||
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||||
|
"""Helper for validate_login
|
||||||
|
|
||||||
|
Handles login, once we've mapped 3pids onto userids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: the username, from the identifier dict
|
||||||
login_submission: the whole of the login submission
|
login_submission: the whole of the login submission
|
||||||
(including 'type' and other relevant fields)
|
(including 'type' and other relevant fields)
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -843,38 +991,18 @@ class AuthHandler(BaseHandler):
|
||||||
SynapseError if there was a problem with the request
|
SynapseError if there was a problem with the request
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if username.startswith("@"):
|
if username.startswith("@"):
|
||||||
qualified_user_id = username
|
qualified_user_id = username
|
||||||
else:
|
else:
|
||||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
login_type = login_submission.get("type")
|
login_type = login_submission.get("type")
|
||||||
|
# we already checked that we have a valid login type
|
||||||
|
assert isinstance(login_type, str)
|
||||||
|
|
||||||
known_login_type = False
|
known_login_type = False
|
||||||
|
|
||||||
# special case to check for "password" for the check_password interface
|
|
||||||
# for the auth providers
|
|
||||||
password = login_submission.get("password")
|
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD:
|
|
||||||
if not self._password_enabled:
|
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
|
||||||
if not password:
|
|
||||||
raise SynapseError(400, "Missing parameter: password")
|
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
|
|
||||||
known_login_type = True
|
|
||||||
is_valid = await provider.check_password(qualified_user_id, password)
|
|
||||||
if is_valid:
|
|
||||||
return qualified_user_id, None
|
|
||||||
|
|
||||||
if not hasattr(provider, "get_supported_login_types") or not hasattr(
|
|
||||||
provider, "check_auth"
|
|
||||||
):
|
|
||||||
# this password provider doesn't understand custom login types
|
|
||||||
continue
|
|
||||||
|
|
||||||
supported_login_types = provider.get_supported_login_types()
|
supported_login_types = provider.get_supported_login_types()
|
||||||
if login_type not in supported_login_types:
|
if login_type not in supported_login_types:
|
||||||
# this password provider doesn't understand this login type
|
# this password provider doesn't understand this login type
|
||||||
|
@ -899,15 +1027,17 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
result = await provider.check_auth(username, login_type, login_dict)
|
result = await provider.check_auth(username, login_type, login_dict)
|
||||||
if result:
|
if result:
|
||||||
if isinstance(result, str):
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||||
known_login_type = True
|
known_login_type = True
|
||||||
|
|
||||||
|
# we've already checked that there is a (valid) password field
|
||||||
|
password = login_submission["password"]
|
||||||
|
assert isinstance(password, str)
|
||||||
|
|
||||||
canonical_user_id = await self._check_local_password(
|
canonical_user_id = await self._check_local_password(
|
||||||
qualified_user_id, password # type: ignore
|
qualified_user_id, password
|
||||||
)
|
)
|
||||||
|
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
|
@ -938,18 +1068,8 @@ class AuthHandler(BaseHandler):
|
||||||
unsuccessful, `user_id` and `callback` are both `None`.
|
unsuccessful, `user_id` and `callback` are both `None`.
|
||||||
"""
|
"""
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_3pid_auth"):
|
|
||||||
# This function is able to return a deferred that either
|
|
||||||
# resolves None, meaning authentication failure, or upon
|
|
||||||
# success, to a str (which is the user_id) or a tuple of
|
|
||||||
# (user_id, callback_func), where callback_func should be run
|
|
||||||
# after we've finished everything else
|
|
||||||
result = await provider.check_3pid_auth(medium, address, password)
|
result = await provider.check_3pid_auth(medium, address, password)
|
||||||
if result:
|
if result:
|
||||||
# Check if the return value is a str or a tuple
|
|
||||||
if isinstance(result, str):
|
|
||||||
# If it's a str, set callback function to None
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
@ -1008,16 +1128,11 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
await provider.on_logged_out(
|
||||||
# This might return an awaitable, if it does block the log out
|
|
||||||
# until it completes.
|
|
||||||
result = provider.on_logged_out(
|
|
||||||
user_id=user_info.user_id,
|
user_id=user_info.user_id,
|
||||||
device_id=user_info.device_id,
|
device_id=user_info.device_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(result):
|
|
||||||
await result
|
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info.token_id is not None:
|
if user_info.token_id is not None:
|
||||||
|
@ -1046,7 +1161,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
|
||||||
for token, token_id, device_id in tokens_and_devices:
|
for token, token_id, device_id in tokens_and_devices:
|
||||||
await provider.on_logged_out(
|
await provider.on_logged_out(
|
||||||
user_id=user_id, device_id=device_id, access_token=token
|
user_id=user_id, device_id=device_id, access_token=token
|
||||||
|
@ -1374,3 +1488,127 @@ class MacaroonGenerator:
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordProvider:
|
||||||
|
"""Wrapper for a password auth provider module
|
||||||
|
|
||||||
|
This class abstracts out all of the backwards-compatibility hacks for
|
||||||
|
password providers, to provide a consistent interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||||
|
try:
|
||||||
|
pp = module(config=config, account_handler=module_api)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error while initializing %r: %s", module, e)
|
||||||
|
raise
|
||||||
|
return cls(pp, module_api)
|
||||||
|
|
||||||
|
def __init__(self, pp, module_api: ModuleApi):
|
||||||
|
self._pp = pp
|
||||||
|
self._module_api = module_api
|
||||||
|
|
||||||
|
self._supported_login_types = {}
|
||||||
|
|
||||||
|
# grandfather in check_password support
|
||||||
|
if hasattr(self._pp, "check_password"):
|
||||||
|
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||||
|
|
||||||
|
g = getattr(self._pp, "get_supported_login_types", None)
|
||||||
|
if g:
|
||||||
|
self._supported_login_types.update(g())
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self._pp)
|
||||||
|
|
||||||
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
|
"""Get the login types supported by this password provider
|
||||||
|
|
||||||
|
Returns a map from a login type identifier (such as m.login.password) to an
|
||||||
|
iterable giving the fields which must be provided by the user in the submission
|
||||||
|
to the /login API.
|
||||||
|
|
||||||
|
This wrapper adds m.login.password to the list if the underlying password
|
||||||
|
provider supports the check_password() api.
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
async def check_auth(
|
||||||
|
self, username: str, login_type: str, login_dict: JsonDict
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
"""Check if the user has presented valid login credentials
|
||||||
|
|
||||||
|
This wrapper also calls check_password() if the underlying password provider
|
||||||
|
supports the check_password() api and the login type is m.login.password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: user id presented by the client. Either an MXID or an unqualified
|
||||||
|
username.
|
||||||
|
|
||||||
|
login_type: the login type being attempted - one of the types returned by
|
||||||
|
get_supported_login_types()
|
||||||
|
|
||||||
|
login_dict: the dictionary of login secrets passed by the client.
|
||||||
|
|
||||||
|
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
|
||||||
|
user, and `callback` is an optional callback which will be called with the
|
||||||
|
result from the /login call (including access_token, device_id, etc.)
|
||||||
|
"""
|
||||||
|
# first grandfather in a call to check_password
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
g = getattr(self._pp, "check_password", None)
|
||||||
|
if g:
|
||||||
|
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||||
|
is_valid = await self._pp.check_password(
|
||||||
|
qualified_user_id, login_dict["password"]
|
||||||
|
)
|
||||||
|
if is_valid:
|
||||||
|
return qualified_user_id, None
|
||||||
|
|
||||||
|
g = getattr(self._pp, "check_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
result = await g(username, login_type, login_dict)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def check_3pid_auth(
|
||||||
|
self, medium: str, address: str, password: str
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
g = getattr(self._pp, "check_3pid_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# This function is able to return a deferred that either
|
||||||
|
# resolves None, meaning authentication failure, or upon
|
||||||
|
# success, to a str (which is the user_id) or a tuple of
|
||||||
|
# (user_id, callback_func), where callback_func should be run
|
||||||
|
# after we've finished everything else
|
||||||
|
result = await g(medium, address, password)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def on_logged_out(
|
||||||
|
self, user_id: str, device_id: Optional[str], access_token: str
|
||||||
|
) -> None:
|
||||||
|
g = getattr(self._pp, "on_logged_out", None)
|
||||||
|
if not g:
|
||||||
|
return
|
||||||
|
|
||||||
|
# This might return an awaitable, if it does block the log out
|
||||||
|
# until it completes.
|
||||||
|
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
|
|
@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
|
||||||
raise SynapseError(500, "An error was encountered when sending the email")
|
raise SynapseError(500, "An error was encountered when sending the email")
|
||||||
|
|
||||||
token_expires = (
|
token_expires = (
|
||||||
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
|
self.hs.get_clock().time_msec()
|
||||||
|
+ self.hs.config.email_validation_token_lifetime
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.start_or_continue_validation_session(
|
await self.store.start_or_continue_validation_session(
|
||||||
|
|
|
@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
|
||||||
|
|
||||||
return UserAttributes(**attributes)
|
return UserAttributes(**attributes)
|
||||||
|
|
||||||
|
async def grandfather_existing_users() -> Optional[str]:
|
||||||
|
if self._allow_existing_users:
|
||||||
|
# If allowing existing users we want to generate a single localpart
|
||||||
|
# and attempt to match it.
|
||||||
|
attributes = await oidc_response_to_user_attributes(failures=0)
|
||||||
|
|
||||||
|
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||||
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
if users:
|
||||||
|
# If an existing matrix ID is returned, then use it.
|
||||||
|
if len(users) == 1:
|
||||||
|
previously_registered_user_id = next(iter(users))
|
||||||
|
elif user_id in users:
|
||||||
|
previously_registered_user_id = user_id
|
||||||
|
else:
|
||||||
|
# Do not attempt to continue generating Matrix IDs.
|
||||||
|
raise MappingException(
|
||||||
|
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||||
|
user_id, users
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
user_agent,
|
||||||
ip_address,
|
ip_address,
|
||||||
oidc_response_to_user_attributes,
|
oidc_response_to_user_attributes,
|
||||||
self._allow_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -299,15 +299,20 @@ class PaginationHandler:
|
||||||
"""
|
"""
|
||||||
return self._purges_by_id.get(purge_id)
|
return self._purges_by_id.get(purge_id)
|
||||||
|
|
||||||
async def purge_room(self, room_id: str) -> None:
|
async def purge_room(self, room_id: str, force: bool = False) -> None:
|
||||||
"""Purge the given room from the database"""
|
"""Purge the given room from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: room to be purged
|
||||||
|
force: set true to skip checking for joined users.
|
||||||
|
"""
|
||||||
with await self.pagination_lock.write(room_id):
|
with await self.pagination_lock.write(room_id):
|
||||||
# check we know about the room
|
# check we know about the room
|
||||||
await self.store.get_room_version_id(room_id)
|
await self.store.get_room_version_id(room_id)
|
||||||
|
|
||||||
# first check that we have no users in this room
|
# first check that we have no users in this room
|
||||||
|
if not force:
|
||||||
joined = await self.store.is_host_joined(room_id, self._server_name)
|
joined = await self.store.is_host_joined(room_id, self._server_name)
|
||||||
|
|
||||||
if joined:
|
if joined:
|
||||||
raise SynapseError(400, "Users are still joined to this room")
|
raise SynapseError(400, "Users are still joined to this room")
|
||||||
|
|
||||||
|
|
|
@ -366,7 +366,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# later on.
|
# later on.
|
||||||
content = dict(content)
|
content = dict(content)
|
||||||
|
|
||||||
if not self.allow_per_room_profiles or requester.shadow_banned:
|
# allow the server notices mxid to set room-level profile
|
||||||
|
is_requester_server_notices_user = (
|
||||||
|
self._server_notices_mxid is not None
|
||||||
|
and requester.user.to_string() == self._server_notices_mxid
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
||||||
|
) or requester.shadow_banned:
|
||||||
# Strip profile data, knowing that new profile data will be added to the
|
# Strip profile data, knowing that new profile data will be added to the
|
||||||
# event's content in event_creation_handler.create_event() using the target's
|
# event's content in event_creation_handler.create_event() using the target's
|
||||||
# global profile.
|
# global profile.
|
||||||
|
|
|
@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
|
||||||
return UserAttributes(
|
return UserAttributes(
|
||||||
localpart=result.get("mxid_localpart"),
|
localpart=result.get("mxid_localpart"),
|
||||||
display_name=result.get("displayname"),
|
display_name=result.get("displayname"),
|
||||||
emails=result.get("emails"),
|
emails=result.get("emails", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
async def grandfather_existing_users() -> Optional[str]:
|
||||||
# backwards-compatibility hack: see if there is an existing user with a
|
# backwards-compatibility hack: see if there is an existing user with a
|
||||||
# suitable mapping from the uid
|
# suitable mapping from the uid
|
||||||
if (
|
if (
|
||||||
|
@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
|
||||||
if users:
|
if users:
|
||||||
registered_user_id = list(users.keys())[0]
|
registered_user_id = list(users.keys())[0]
|
||||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||||
await self.store.record_user_external_id(
|
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
|
||||||
)
|
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
user_agent,
|
||||||
ip_address,
|
ip_address,
|
||||||
saml_response_to_remapped_user_attributes,
|
saml_response_to_remapped_user_attributes,
|
||||||
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
|
|
|
@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
|
||||||
user_agent: str,
|
user_agent: str,
|
||||||
ip_address: str,
|
ip_address: str,
|
||||||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
allow_existing_users: bool = False,
|
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||||
|
@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
|
||||||
if it has that matrix ID is returned regardless of the current mapping
|
if it has that matrix ID is returned regardless of the current mapping
|
||||||
logic.
|
logic.
|
||||||
|
|
||||||
|
If a callable is provided for grandfathering users, it is called and can
|
||||||
|
potentially return a matrix ID to use. If it does, the SSO ID is linked to
|
||||||
|
this matrix ID for subsequent calls.
|
||||||
|
|
||||||
The mapping function is called (potentially multiple times) to generate
|
The mapping function is called (potentially multiple times) to generate
|
||||||
a localpart for the user.
|
a localpart for the user.
|
||||||
|
|
||||||
|
@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
|
||||||
given user-agent and IP address and the SSO ID is linked to this matrix
|
given user-agent and IP address and the SSO ID is linked to this matrix
|
||||||
ID for subsequent calls.
|
ID for subsequent calls.
|
||||||
|
|
||||||
If allow_existing_users is true the mapping function is only called once
|
|
||||||
and results in:
|
|
||||||
|
|
||||||
1. The use of a previously registered matrix ID. In this case, the
|
|
||||||
SSO ID is linked to the matrix ID. (Note it is possible that
|
|
||||||
other SSO IDs are linked to the same matrix ID.)
|
|
||||||
2. An unused localpart, in which case the user is registered (as
|
|
||||||
discussed above).
|
|
||||||
3. An error if the generated localpart matches multiple pre-existing
|
|
||||||
matrix IDs. Generally this should not happen.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
"oidc" or "saml".
|
"oidc" or "saml".
|
||||||
|
@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
|
||||||
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
||||||
The only parameter is an integer which represents the amount of
|
The only parameter is an integer which represents the amount of
|
||||||
times the returned mxid localpart mapping has failed.
|
times the returned mxid localpart mapping has failed.
|
||||||
allow_existing_users: True if the localpart returned from the
|
grandfather_existing_users: A callable which can return an previously
|
||||||
mapping provider can be linked to an existing matrix ID.
|
existing matrix ID. The SSO ID is then linked to the returned
|
||||||
|
matrix ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The user ID associated with the SSO response.
|
The user ID associated with the SSO response.
|
||||||
|
@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
|
||||||
if previously_registered_user_id:
|
if previously_registered_user_id:
|
||||||
return previously_registered_user_id
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
# Check for grandfathering of users.
|
||||||
|
if grandfather_existing_users:
|
||||||
|
previously_registered_user_id = await grandfather_existing_users()
|
||||||
|
if previously_registered_user_id:
|
||||||
|
# Future logins should also match this user ID.
|
||||||
|
await self.store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
|
)
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
# Otherwise, generate a new user.
|
# Otherwise, generate a new user.
|
||||||
for i in range(self._MAP_USERNAME_RETRIES):
|
for i in range(self._MAP_USERNAME_RETRIES):
|
||||||
try:
|
try:
|
||||||
|
@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
||||||
# Note, if allow_existing_users is true then the loop is guaranteed
|
|
||||||
# to end on the first iteration: either by matching an existing user,
|
|
||||||
# raising an error, or registering a new user. See the docstring for
|
|
||||||
# more in-depth an explanation.
|
|
||||||
if users and allow_existing_users:
|
|
||||||
# If an existing matrix ID is returned, then use it.
|
|
||||||
if len(users) == 1:
|
|
||||||
previously_registered_user_id = next(iter(users))
|
|
||||||
elif user_id in users:
|
|
||||||
previously_registered_user_id = user_id
|
|
||||||
else:
|
|
||||||
# Do not attempt to continue generating Matrix IDs.
|
|
||||||
raise MappingException(
|
|
||||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
|
||||||
user_id, users
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Future logins should also match this user ID.
|
|
||||||
await self.store.record_user_external_id(
|
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
elif not users:
|
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -25,7 +25,7 @@ from io import BytesIO
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
|
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
|
from canonicaljson import iterencode_canonical_json
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import defer, interfaces
|
from twisted.internet import defer, interfaces
|
||||||
|
@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request,
|
request, error_code, error_dict, send_cors=True,
|
||||||
error_code,
|
|
||||||
error_dict,
|
|
||||||
send_cors=True,
|
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
|
||||||
code,
|
code,
|
||||||
response_object,
|
response_object,
|
||||||
send_cors=True,
|
send_cors=True,
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
|
||||||
canonical_json=self.canonical_json,
|
canonical_json=self.canonical_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -587,7 +582,6 @@ def respond_with_json(
|
||||||
code: int,
|
code: int,
|
||||||
json_object: Any,
|
json_object: Any,
|
||||||
send_cors: bool = False,
|
send_cors: bool = False,
|
||||||
pretty_print: bool = False,
|
|
||||||
canonical_json: bool = True,
|
canonical_json: bool = True,
|
||||||
):
|
):
|
||||||
"""Sends encoded JSON in response to the given request.
|
"""Sends encoded JSON in response to the given request.
|
||||||
|
@ -598,8 +592,6 @@ def respond_with_json(
|
||||||
json_object: The object to serialize to JSON.
|
json_object: The object to serialize to JSON.
|
||||||
send_cors: Whether to send Cross-Origin Resource Sharing headers
|
send_cors: Whether to send Cross-Origin Resource Sharing headers
|
||||||
https://fetch.spec.whatwg.org/#http-cors-protocol
|
https://fetch.spec.whatwg.org/#http-cors-protocol
|
||||||
pretty_print: Whether to include indentation and line-breaks in the
|
|
||||||
resulting JSON bytes.
|
|
||||||
canonical_json: Whether to use the canonicaljson algorithm when encoding
|
canonical_json: Whether to use the canonicaljson algorithm when encoding
|
||||||
the JSON bytes.
|
the JSON bytes.
|
||||||
|
|
||||||
|
@ -615,9 +607,6 @@ def respond_with_json(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if pretty_print:
|
|
||||||
encoder = iterencode_pretty_printed_json
|
|
||||||
else:
|
|
||||||
if canonical_json:
|
if canonical_json:
|
||||||
encoder = iterencode_canonical_json
|
encoder = iterencode_canonical_json
|
||||||
else:
|
else:
|
||||||
|
@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
|
||||||
)
|
)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
b"Access-Control-Allow-Headers",
|
b"Access-Control-Allow-Headers",
|
||||||
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
|
b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -759,11 +748,3 @@ def finish_request(request: Request):
|
||||||
request.finish()
|
request.finish()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.info("Connection disconnected before response was written: %r", e)
|
logger.info("Connection disconnected before response was written: %r", e)
|
||||||
|
|
||||||
|
|
||||||
def _request_user_agent_is_curl(request: Request) -> bool:
|
|
||||||
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
|
|
||||||
for user_agent in user_agents:
|
|
||||||
if b"curl" in user_agent:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
|
@ -75,6 +75,7 @@ class HttpPusher:
|
||||||
self.failing_since = pusherdict["failing_since"]
|
self.failing_since = pusherdict["failing_since"]
|
||||||
self.timed_call = None
|
self.timed_call = None
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
|
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
||||||
|
|
||||||
# This is the highest stream ordering we know it's safe to process.
|
# This is the highest stream ordering we know it's safe to process.
|
||||||
# When new events arrive, we'll be given a window of new events: we
|
# When new events arrive, we'll be given a window of new events: we
|
||||||
|
@ -140,7 +141,11 @@ class HttpPusher:
|
||||||
async def _update_badge(self):
|
async def _update_badge(self):
|
||||||
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
||||||
# to be largely redundant. perhaps we can remove it.
|
# to be largely redundant. perhaps we can remove it.
|
||||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
badge = await push_tools.get_badge_count(
|
||||||
|
self.hs.get_datastore(),
|
||||||
|
self.user_id,
|
||||||
|
group_by_room=self._group_unread_count_by_room,
|
||||||
|
)
|
||||||
await self._send_badge(badge)
|
await self._send_badge(badge)
|
||||||
|
|
||||||
def on_timer(self):
|
def on_timer(self):
|
||||||
|
@ -287,7 +292,11 @@ class HttpPusher:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
|
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
|
||||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
badge = await push_tools.get_badge_count(
|
||||||
|
self.hs.get_datastore(),
|
||||||
|
self.user_id,
|
||||||
|
group_by_room=self._group_unread_count_by_room,
|
||||||
|
)
|
||||||
|
|
||||||
event = await self.store.get_event(push_action["event_id"], allow_none=True)
|
event = await self.store.get_event(push_action["event_id"], allow_none=True)
|
||||||
if event is None:
|
if event is None:
|
||||||
|
|
|
@ -12,12 +12,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||||
from synapse.storage import Storage
|
from synapse.storage import Storage
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
async def get_badge_count(store, user_id):
|
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
|
||||||
invites = await store.get_invited_rooms_for_local_user(user_id)
|
invites = await store.get_invited_rooms_for_local_user(user_id)
|
||||||
joins = await store.get_rooms_for_user(user_id)
|
joins = await store.get_rooms_for_user(user_id)
|
||||||
|
|
||||||
|
@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
|
||||||
room_id, user_id, last_unread_event_id
|
room_id, user_id, last_unread_event_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# return one badge count per conversation, as count per
|
if notifs["notify_count"] == 0:
|
||||||
# message is so noisy as to be almost useless
|
continue
|
||||||
badge += 1 if notifs["notify_count"] else 0
|
|
||||||
|
if group_by_room:
|
||||||
|
# return one badge count per conversation
|
||||||
|
badge += 1
|
||||||
|
else:
|
||||||
|
# increment the badge count by the number of unread messages in the room
|
||||||
|
badge += notifs["notify_count"]
|
||||||
return badge
|
return badge
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,10 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload( # type: ignore
|
||||||
requester, room_id, user_id, remote_room_hosts, content
|
requester: Requester,
|
||||||
):
|
room_id: str,
|
||||||
|
user_id: str,
|
||||||
|
remote_room_hosts: List[str],
|
||||||
|
content: JsonDict,
|
||||||
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
requester(Requester)
|
requester: The user making the request according to the access token
|
||||||
room_id (str)
|
room_id: The ID of the room.
|
||||||
user_id (str)
|
user_id: The ID of the user.
|
||||||
remote_room_hosts (list[str]): Servers to try and join via
|
remote_room_hosts: Servers to try and join via
|
||||||
content(dict): The event content to use for the join event
|
content: The event content to use for the join event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict representing the payload of the request.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"requester": requester.serialize(),
|
"requester": requester.serialize(),
|
||||||
|
@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, room_id, user_id):
|
async def _handle_request( # type: ignore
|
||||||
|
self, request: Request, room_id: str, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
remote_room_hosts = content["remote_room_hosts"]
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
|
@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||||
txn_id: Optional[str],
|
txn_id: Optional[str],
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
invite_event_id: ID of the invite to be rejected
|
invite_event_id: The ID of the invite to be rejected.
|
||||||
txn_id: optional transaction ID supplied by the client
|
txn_id: Optional transaction ID supplied by the client
|
||||||
requester: user making the rejection request, according to the access token
|
requester: User making the rejection request, according to the access token
|
||||||
content: additional content to include in the rejection event.
|
content: Additional content to include in the rejection event.
|
||||||
Normally an empty dict.
|
Normally an empty dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict representing the payload of the request.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"txn_id": txn_id,
|
"txn_id": txn_id,
|
||||||
|
@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, invite_event_id):
|
async def _handle_request( # type: ignore
|
||||||
|
self, request: Request, invite_event_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
txn_id = content["txn_id"]
|
txn_id = content["txn_id"]
|
||||||
|
@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(room_id, user_id, change):
|
async def _serialize_payload( # type: ignore
|
||||||
|
room_id: str, user_id: str, change: str
|
||||||
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id: The ID of the room.
|
||||||
user_id (str)
|
user_id: The ID of the user.
|
||||||
change (str): "left"
|
change: "left"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict representing the payload of the request.
|
||||||
"""
|
"""
|
||||||
assert change == "left"
|
assert change == "left"
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _handle_request(self, request, room_id, user_id, change):
|
def _handle_request( # type: ignore
|
||||||
|
self, request: Request, room_id: str, user_id: str, change: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
logger.info("user membership change: %s in %s", user_id, room_id)
|
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
|
@ -70,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
class DeleteRoomRestServlet(RestServlet):
|
class DeleteRoomRestServlet(RestServlet):
|
||||||
"""Delete a room from server. It is a combination and improvement of
|
"""Delete a room from server.
|
||||||
shut down and purge room.
|
|
||||||
|
It is a combination and improvement of shutdown and purge room.
|
||||||
|
|
||||||
Shuts down a room by removing all local users from the room.
|
Shuts down a room by removing all local users from the room.
|
||||||
Blocking all future invites and joins to the room is optional.
|
Blocking all future invites and joins to the room is optional.
|
||||||
|
|
||||||
If desired any local aliases will be repointed to a new room
|
If desired any local aliases will be repointed to a new room
|
||||||
created by `new_room_user_id` and kicked users will be auto
|
created by `new_room_user_id` and kicked users will be auto-
|
||||||
joined to the new room.
|
joined to the new room.
|
||||||
It will remove all trace of a room from the database.
|
|
||||||
|
If 'purge' is true, it will remove all traces of a room from the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
|
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
|
||||||
|
@ -110,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
|
||||||
Codes.BAD_JSON,
|
Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
force_purge = content.get("force_purge", False)
|
||||||
|
if not isinstance(force_purge, bool):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"Param 'force_purge' must be a boolean, if given",
|
||||||
|
Codes.BAD_JSON,
|
||||||
|
)
|
||||||
|
|
||||||
ret = await self.room_shutdown_handler.shutdown_room(
|
ret = await self.room_shutdown_handler.shutdown_room(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
new_room_user_id=content.get("new_room_user_id"),
|
new_room_user_id=content.get("new_room_user_id"),
|
||||||
|
@ -121,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
|
||||||
|
|
||||||
# Purge room
|
# Purge room
|
||||||
if purge:
|
if purge:
|
||||||
await self.pagination_handler.purge_room(room_id)
|
await self.pagination_handler.purge_room(room_id, force=force_purge)
|
||||||
|
|
||||||
return (200, ret)
|
return (200, ret)
|
||||||
|
|
||||||
|
|
|
@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.auth import (
|
|
||||||
convert_client_dict_legacy_fields_to_identifier,
|
|
||||||
login_id_phone_to_thirdparty,
|
|
||||||
)
|
|
||||||
from synapse.http.server import finish_request
|
from synapse.http.server import finish_request
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
|
@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
from synapse.rest.well_known import WellKnownBuilder
|
from synapse.rest.well_known import WellKnownBuilder
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util.threepids import canonicalise_email
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
|
||||||
rate_hz=self.hs.config.rc_login_account.per_second,
|
rate_hz=self.hs.config.rc_login_account.per_second,
|
||||||
burst_count=self.hs.config.rc_login_account.burst_count,
|
burst_count=self.hs.config.rc_login_account.burst_count,
|
||||||
)
|
)
|
||||||
self._failed_attempts_ratelimiter = Ratelimiter(
|
|
||||||
clock=hs.get_clock(),
|
|
||||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_GET(self, request: SynapseRequest):
|
def on_GET(self, request: SynapseRequest):
|
||||||
flows = []
|
flows = []
|
||||||
|
@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
|
||||||
result["well_known"] = well_known_data
|
result["well_known"] = well_known_data
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
def _get_qualified_user_id(self, identifier):
|
|
||||||
if identifier["type"] != "m.id.user":
|
|
||||||
raise SynapseError(400, "Unknown login identifier type")
|
|
||||||
if "user" not in identifier:
|
|
||||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
|
||||||
|
|
||||||
if identifier["user"].startswith("@"):
|
|
||||||
return identifier["user"]
|
|
||||||
else:
|
|
||||||
return UserID(identifier["user"], self.hs.hostname).to_string()
|
|
||||||
|
|
||||||
async def _do_appservice_login(
|
async def _do_appservice_login(
|
||||||
self, login_submission: JsonDict, appservice: ApplicationService
|
self, login_submission: JsonDict, appservice: ApplicationService
|
||||||
):
|
):
|
||||||
logger.info(
|
identifier = login_submission.get("identifier")
|
||||||
"Got appservice login request with identifier: %r",
|
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||||
login_submission.get("identifier"),
|
|
||||||
|
if not isinstance(identifier, dict):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Invalid identifier in login submission", Codes.INVALID_PARAM
|
||||||
)
|
)
|
||||||
|
|
||||||
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
|
# this login flow only supports identifiers of type "m.id.user".
|
||||||
qualified_user_id = self._get_qualified_user_id(identifier)
|
if identifier.get("type") != "m.id.user":
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Unknown login identifier type", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
user = identifier.get("user")
|
||||||
|
if not isinstance(user, str):
|
||||||
|
raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
|
if user.startswith("@"):
|
||||||
|
qualified_user_id = user
|
||||||
|
else:
|
||||||
|
qualified_user_id = UserID(user, self.hs.hostname).to_string()
|
||||||
|
|
||||||
if not appservice.is_interested_in_user(qualified_user_id):
|
if not appservice.is_interested_in_user(qualified_user_id):
|
||||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||||
|
@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
|
||||||
login_submission.get("address"),
|
login_submission.get("address"),
|
||||||
login_submission.get("user"),
|
login_submission.get("user"),
|
||||||
)
|
)
|
||||||
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
|
|
||||||
|
|
||||||
# convert phone type identifiers to generic threepids
|
|
||||||
if identifier["type"] == "m.id.phone":
|
|
||||||
identifier = login_id_phone_to_thirdparty(identifier)
|
|
||||||
|
|
||||||
# convert threepid identifiers to user IDs
|
|
||||||
if identifier["type"] == "m.id.thirdparty":
|
|
||||||
address = identifier.get("address")
|
|
||||||
medium = identifier.get("medium")
|
|
||||||
|
|
||||||
if medium is None or address is None:
|
|
||||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
|
||||||
|
|
||||||
# For emails, canonicalise the address.
|
|
||||||
# We store all email addresses canonicalised in the DB.
|
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
|
||||||
if medium == "email":
|
|
||||||
try:
|
|
||||||
address = canonicalise_email(address)
|
|
||||||
except ValueError as e:
|
|
||||||
raise SynapseError(400, str(e))
|
|
||||||
|
|
||||||
# We also apply account rate limiting using the 3PID as a key, as
|
|
||||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
|
||||||
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
|
|
||||||
|
|
||||||
# Check for login providers that support 3pid login types
|
|
||||||
(
|
|
||||||
canonical_user_id,
|
|
||||||
callback_3pid,
|
|
||||||
) = await self.auth_handler.check_password_provider_3pid(
|
|
||||||
medium, address, login_submission["password"]
|
|
||||||
)
|
|
||||||
if canonical_user_id:
|
|
||||||
# Authentication through password provider and 3pid succeeded
|
|
||||||
|
|
||||||
result = await self._complete_login(
|
|
||||||
canonical_user_id, login_submission, callback_3pid
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# No password providers were able to handle this 3pid
|
|
||||||
# Check local store
|
|
||||||
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
|
||||||
medium, address
|
|
||||||
)
|
|
||||||
if not user_id:
|
|
||||||
logger.warning(
|
|
||||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
|
||||||
)
|
|
||||||
# We mark that we've failed to log in here, as
|
|
||||||
# `check_password_provider_3pid` might have returned `None` due
|
|
||||||
# to an incorrect password, rather than the account not
|
|
||||||
# existing.
|
|
||||||
#
|
|
||||||
# If it returned None but the 3PID was bound then we won't hit
|
|
||||||
# this code path, which is fine as then the per-user ratelimit
|
|
||||||
# will kick in below.
|
|
||||||
self._failed_attempts_ratelimiter.can_do_action((medium, address))
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
identifier = {"type": "m.id.user", "user": user_id}
|
|
||||||
|
|
||||||
# by this point, the identifier should be an m.id.user: if it's anything
|
|
||||||
# else, we haven't understood it.
|
|
||||||
qualified_user_id = self._get_qualified_user_id(identifier)
|
|
||||||
|
|
||||||
# Check if we've hit the failed ratelimit (but don't update it)
|
|
||||||
self._failed_attempts_ratelimiter.ratelimit(
|
|
||||||
qualified_user_id.lower(), update=False
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
canonical_user_id, callback = await self.auth_handler.validate_login(
|
canonical_user_id, callback = await self.auth_handler.validate_login(
|
||||||
identifier["user"], login_submission
|
login_submission, ratelimit=True
|
||||||
)
|
)
|
||||||
except LoginError:
|
|
||||||
# The user has failed to log in, so we need to update the rate
|
|
||||||
# limiter. Using `can_do_action` avoids us raising a ratelimit
|
|
||||||
# exception and masking the LoginError. The actual ratelimiting
|
|
||||||
# should have happened above.
|
|
||||||
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
|
|
||||||
raise
|
|
||||||
|
|
||||||
result = await self._complete_login(
|
result = await self._complete_login(
|
||||||
canonical_user_id, login_submission, callback
|
canonical_user_id, login_submission, callback
|
||||||
)
|
)
|
||||||
|
|
|
@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||||
# comments for request_token_inhibit_3pid_errors.
|
# comments for request_token_inhibit_3pid_errors.
|
||||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||||
# look like we did something.
|
# look like we did something.
|
||||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||||
return 200, {"sid": random_string(16)}
|
return 200, {"sid": random_string(16)}
|
||||||
|
|
||||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||||
|
@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
# comments for request_token_inhibit_3pid_errors.
|
# comments for request_token_inhibit_3pid_errors.
|
||||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||||
# look like we did something.
|
# look like we did something.
|
||||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||||
return 200, {"sid": random_string(16)}
|
return 200, {"sid": random_string(16)}
|
||||||
|
|
||||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
# comments for request_token_inhibit_3pid_errors.
|
# comments for request_token_inhibit_3pid_errors.
|
||||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||||
# look like we did something.
|
# look like we did something.
|
||||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||||
return 200, {"sid": random_string(16)}
|
return 200, {"sid": random_string(16)}
|
||||||
|
|
||||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
|
|
@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||||
# comments for request_token_inhibit_3pid_errors.
|
# comments for request_token_inhibit_3pid_errors.
|
||||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||||
# look like we did something.
|
# look like we did something.
|
||||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||||
return 200, {"sid": random_string(16)}
|
return 200, {"sid": random_string(16)}
|
||||||
|
|
||||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||||
# comments for request_token_inhibit_3pid_errors.
|
# comments for request_token_inhibit_3pid_errors.
|
||||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||||
# look like we did something.
|
# look like we did something.
|
||||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||||
return 200, {"sid": random_string(16)}
|
return 200, {"sid": random_string(16)}
|
||||||
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
|
|
@ -66,7 +66,7 @@ class LocalKey(Resource):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.clock = hs.clock
|
self.clock = hs.get_clock()
|
||||||
self.update_response_body(self.clock.time_msec())
|
self.update_response_body(self.clock.time_msec())
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
|
|
|
@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
|
||||||
"@cache_in_self can only be used on functions starting with `get_`"
|
"@cache_in_self can only be used on functions starting with `get_`"
|
||||||
)
|
)
|
||||||
|
|
||||||
depname = builder.__name__[len("get_") :]
|
# get_attr -> _attr
|
||||||
|
depname = builder.__name__[len("get") :]
|
||||||
|
|
||||||
building = [False]
|
building = [False]
|
||||||
|
|
||||||
|
@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
self._instance_id = random_string(5)
|
self._instance_id = random_string(5)
|
||||||
self._instance_name = config.worker_name or "master"
|
self._instance_name = config.worker_name or "master"
|
||||||
|
|
||||||
self.clock = Clock(reactor)
|
|
||||||
self.distributor = Distributor()
|
|
||||||
|
|
||||||
self.registration_ratelimiter = Ratelimiter(
|
|
||||||
clock=self.clock,
|
|
||||||
rate_hz=config.rc_registration.per_second,
|
|
||||||
burst_count=config.rc_registration.burst_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.version_string = version_string
|
self.version_string = version_string
|
||||||
|
|
||||||
self.datastores = None # type: Optional[Databases]
|
self.datastores = None # type: Optional[Databases]
|
||||||
|
@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def is_mine_id(self, string: str) -> bool:
|
def is_mine_id(self, string: str) -> bool:
|
||||||
return string.split(":", 1)[1] == self.hostname
|
return string.split(":", 1)[1] == self.hostname
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
def get_clock(self) -> Clock:
|
def get_clock(self) -> Clock:
|
||||||
return self.clock
|
return Clock(self._reactor)
|
||||||
|
|
||||||
def get_datastore(self) -> DataStore:
|
def get_datastore(self) -> DataStore:
|
||||||
if not self.datastores:
|
if not self.datastores:
|
||||||
|
@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_config(self) -> HomeServerConfig:
|
def get_config(self) -> HomeServerConfig:
|
||||||
return self.config
|
return self.config
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
def get_distributor(self) -> Distributor:
|
def get_distributor(self) -> Distributor:
|
||||||
return self.distributor
|
return Distributor()
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
def get_registration_ratelimiter(self) -> Ratelimiter:
|
def get_registration_ratelimiter(self) -> Ratelimiter:
|
||||||
return self.registration_ratelimiter
|
return Ratelimiter(
|
||||||
|
clock=self.get_clock(),
|
||||||
|
rate_hz=self.config.rc_registration.per_second,
|
||||||
|
burst_count=self.config.rc_registration.burst_count,
|
||||||
|
)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_federation_client(self) -> FederationClient:
|
def get_federation_client(self) -> FederationClient:
|
||||||
|
@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_federation_ratelimiter(self) -> FederationRateLimiter:
|
def get_federation_ratelimiter(self) -> FederationRateLimiter:
|
||||||
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
|
return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_module_api(self) -> ModuleApi:
|
def get_module_api(self) -> ModuleApi:
|
||||||
|
|
|
@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||||
for table in (
|
for table in (
|
||||||
"event_auth",
|
"event_auth",
|
||||||
"event_edges",
|
"event_edges",
|
||||||
|
"event_json",
|
||||||
"event_push_actions_staging",
|
"event_push_actions_staging",
|
||||||
"event_reference_hashes",
|
"event_reference_hashes",
|
||||||
"event_relations",
|
"event_relations",
|
||||||
|
@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||||
"destination_rooms",
|
"destination_rooms",
|
||||||
"event_backward_extremities",
|
"event_backward_extremities",
|
||||||
"event_forward_extremities",
|
"event_forward_extremities",
|
||||||
"event_json",
|
|
||||||
"event_push_actions",
|
"event_push_actions",
|
||||||
"event_search",
|
"event_search",
|
||||||
"events",
|
"events",
|
||||||
|
|
|
@ -20,14 +20,14 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
-- add new index that includes method to local media
|
-- add new index that includes method to local media
|
||||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
('local_media_repository_thumbnails_method_idx', '{}');
|
(5807, 'local_media_repository_thumbnails_method_idx', '{}');
|
||||||
|
|
||||||
-- add new index that includes method to remote media
|
-- add new index that includes method to remote media
|
||||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||||
('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
|
(5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
|
||||||
|
|
||||||
-- drop old index
|
-- drop old index
|
||||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||||
('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
|
(5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
|
||||||
|
|
||||||
|
|
|
@ -28,5 +28,5 @@
|
||||||
-- functionality as the old one. This effectively restarts the background job
|
-- functionality as the old one. This effectively restarts the background job
|
||||||
-- from the beginning, without running it twice in a row, supporting both
|
-- from the beginning, without running it twice in a row, supporting both
|
||||||
-- upgrade usecases.
|
-- upgrade usecases.
|
||||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
('populate_stats_process_rooms_2', '{}');
|
(5812, 'populate_stats_process_rooms_2', '{}');
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
('users_have_local_media', '{}');
|
(5822, 'users_have_local_media', '{}');
|
||||||
|
|
|
@ -13,5 +13,5 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
('e2e_cross_signing_keys_idx', '{}');
|
(5823, 'e2e_cross_signing_keys_idx', '{}');
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- this index is essentially redundant. The only time it was ever used was when purging
|
||||||
|
-- rooms - and Synapse 1.24 will change that.
|
||||||
|
|
||||||
|
DROP INDEX IF EXISTS event_json_room_id;
|
|
@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.fail("some_user was not in %s" % macaroon.inspect())
|
self.fail("some_user was not in %s" % macaroon.inspect())
|
||||||
|
|
||||||
def test_macaroon_caveats(self):
|
def test_macaroon_caveats(self):
|
||||||
self.hs.clock.now = 5000
|
self.hs.get_clock().now = 5000
|
||||||
|
|
||||||
token = self.macaroon_generator.generate_access_token("a_user")
|
token = self.macaroon_generator.generate_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
self.hs.clock.now = 1000
|
self.hs.get_clock().now = 1000
|
||||||
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||||
user_id = yield defer.ensureDeferred(
|
user_id = yield defer.ensureDeferred(
|
||||||
|
@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.assertEqual("a_user", user_id)
|
self.assertEqual("a_user", user_id)
|
||||||
|
|
||||||
# when we advance the clock, the token should be rejected
|
# when we advance the clock, the token should be rejected
|
||||||
self.hs.clock.now = 6000
|
self.hs.get_clock().now = 6000
|
||||||
with self.assertRaises(synapse.api.errors.AuthError):
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
|
|
|
@ -23,7 +23,7 @@ import pymacaroons
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
|
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
@ -127,13 +127,8 @@ async def get_json(url):
|
||||||
|
|
||||||
|
|
||||||
class OidcHandlerTestCase(HomeserverTestCase):
|
class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
self.http_client = Mock(spec=["get_json"])
|
|
||||||
self.http_client.get_json.side_effect = get_json
|
|
||||||
self.http_client.user_agent = "Synapse Test"
|
|
||||||
|
|
||||||
config = self.default_config()
|
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
oidc_config = {
|
oidc_config = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
|
@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
oidc_config.update(config.get("oidc_config", {}))
|
oidc_config.update(config.get("oidc_config", {}))
|
||||||
config["oidc_config"] = oidc_config
|
config["oidc_config"] = oidc_config
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
return config
|
||||||
http_client=self.http_client,
|
|
||||||
proxied_http_client=self.http_client,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.handler = OidcHandler(hs)
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
self.http_client = Mock(spec=["get_json"])
|
||||||
|
self.http_client.get_json.side_effect = get_json
|
||||||
|
self.http_client.user_agent = "Synapse Test"
|
||||||
|
|
||||||
|
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||||
|
|
||||||
|
self.handler = hs.get_oidc_handler()
|
||||||
|
sso_handler = hs.get_sso_handler()
|
||||||
# Mock the render error method.
|
# Mock the render error method.
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
self.handler._sso_handler.render_error = self.render_error
|
sso_handler.render_error = self.render_error
|
||||||
|
|
||||||
# Reduce the number of attempts when generating MXIDs.
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(mxid, "@test_user:test")
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
# Subsequent calls should map to the same mxid.
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_userinfo_to_user(
|
||||||
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
||||||
# requires a unique sub, but something that maps to the same matrix ID,
|
# requires a unique sub, but something that maps to the same matrix ID,
|
||||||
# in this case we'll just use the same username. A more realistic example
|
# in this case we'll just use the same username. A more realistic example
|
||||||
|
@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
self.assertEqual(mxid, "@test_user1:test")
|
self.assertEqual(mxid, "@test_user1:test")
|
||||||
|
|
||||||
# Register all of the potential users for a particular username.
|
# Register all of the potential mxids for a particular OIDC username.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
store.register_user(user_id="@tester:test", password_hash=None)
|
store.register_user(user_id="@tester:test", password_hash=None)
|
||||||
)
|
)
|
||||||
|
|
580
tests/handlers/test_password_providers.py
Normal file
580
tests/handlers/test_password_providers.py
Normal file
|
@ -0,0 +1,580 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for the password_auth_provider interface"""
|
||||||
|
|
||||||
|
from typing import Any, Type, Union
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
from synapse.rest.client.v1 import login
|
||||||
|
from synapse.rest.client.v2_alpha import devices
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
from tests.server import FakeChannel
|
||||||
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
# (possibly experimental) login flows we expect to appear in the list after the normal
|
||||||
|
# ones
|
||||||
|
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
|
||||||
|
|
||||||
|
# a mock instance which the dummy auth providers delegate to, so we can see what's going
|
||||||
|
# on
|
||||||
|
mock_password_provider = Mock()
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordOnlyAuthProvider:
|
||||||
|
"""A password_provider which only implements `check_password`."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config, account_handler):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_password(self, *args):
|
||||||
|
return mock_password_provider.check_password(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAuthProvider:
|
||||||
|
"""A password_provider which implements a custom login type."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config, account_handler):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_supported_login_types(self):
|
||||||
|
return {"test.login_type": ["test_field"]}
|
||||||
|
|
||||||
|
def check_auth(self, *args):
|
||||||
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordCustomAuthProvider:
|
||||||
|
"""A password_provider which implements password login via `check_auth`, as well
|
||||||
|
as a custom type."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config, account_handler):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_supported_login_types(self):
|
||||||
|
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
|
||||||
|
|
||||||
|
def check_auth(self, *args):
|
||||||
|
return mock_password_provider.check_auth(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def providers_config(*providers: Type[Any]) -> dict:
|
||||||
|
"""Returns a config dict that will enable the given password auth providers"""
|
||||||
|
return {
|
||||||
|
"password_providers": [
|
||||||
|
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
|
||||||
|
for provider in providers
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
devices.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# we use a global mock device, so make sure we are starting with a clean slate
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||||
|
def test_password_only_auth_provider_login(self):
|
||||||
|
# login flows should only have m.login.password
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
|
# check_password must return an awaitable
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||||
|
channel = self._send_password_login("u", "p")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# login with mxid should work too
|
||||||
|
channel = self._send_password_login("@u:bz", "p")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
|
||||||
|
# in these cases, but at least we can guard against the API changing
|
||||||
|
# unexpectedly
|
||||||
|
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
|
"@ USER🙂NAME :test", " pASS😢word "
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||||
|
def test_password_only_auth_provider_ui_auth(self):
|
||||||
|
"""UI Auth should delegate correctly to the password provider"""
|
||||||
|
|
||||||
|
# create the user, otherwise access doesn't work
|
||||||
|
module_api = self.hs.get_module_api()
|
||||||
|
self.get_success(module_api.register_user("u"))
|
||||||
|
|
||||||
|
# log in twice, to get two devices
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||||
|
tok1 = self.login("u", "p")
|
||||||
|
self.login("u", "p", device_id="dev2")
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# have the auth provider deny the request to start with
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||||
|
|
||||||
|
# make the initial request which returns a 401
|
||||||
|
session = self._start_delete_device_session(tok1, "dev2")
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
# Make another request providing the UI auth flow.
|
||||||
|
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
||||||
|
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# Finally, check the request goes through when we allow it
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||||
|
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
|
||||||
|
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||||
|
def test_local_user_fallback_login(self):
|
||||||
|
"""rejected login should fall back to local db"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
# check_password must return an awaitable
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||||
|
channel = self._send_password_login("u", "p")
|
||||||
|
self.assertEqual(channel.code, 403, channel.result)
|
||||||
|
|
||||||
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||||
|
|
||||||
|
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||||
|
def test_local_user_fallback_ui_auth(self):
|
||||||
|
"""rejected login should fall back to local db"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
# have the auth provider deny the request
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||||
|
|
||||||
|
# log in twice, to get two devices
|
||||||
|
tok1 = self.login("localuser", "localpass")
|
||||||
|
self.login("localuser", "localpass", device_id="dev2")
|
||||||
|
mock_password_provider.check_password.reset_mock()
|
||||||
|
|
||||||
|
# first delete should give a 401
|
||||||
|
session = self._start_delete_device_session(tok1, "dev2")
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
# Wrong password
|
||||||
|
channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
|
||||||
|
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
|
"@localuser:test", "xxx"
|
||||||
|
)
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# Right password
|
||||||
|
channel = self._authed_delete_device(
|
||||||
|
tok1, "dev2", session, "localuser", "localpass"
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
|
"@localuser:test", "localpass"
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(PasswordOnlyAuthProvider),
|
||||||
|
"password_config": {"localdb_enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_no_local_user_fallback_login(self):
|
||||||
|
"""localdb_enabled can block login with the local password
|
||||||
|
"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
# check_password must return an awaitable
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||||
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
|
self.assertEqual(channel.code, 403)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
|
"@localuser:test", "localpass"
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(PasswordOnlyAuthProvider),
|
||||||
|
"password_config": {"localdb_enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_no_local_user_fallback_ui_auth(self):
|
||||||
|
"""localdb_enabled can block ui auth with the local password
|
||||||
|
"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
# allow login via the auth provider
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||||
|
|
||||||
|
# log in twice, to get two devices
|
||||||
|
tok1 = self.login("localuser", "p")
|
||||||
|
self.login("localuser", "p", device_id="dev2")
|
||||||
|
mock_password_provider.check_password.reset_mock()
|
||||||
|
|
||||||
|
# first delete should give a 401
|
||||||
|
channel = self._delete_device(tok1, "dev2")
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
# m.login.password UIA is permitted because the auth provider allows it,
|
||||||
|
# even though the localdb does not.
|
||||||
|
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
# now try deleting with the local password
|
||||||
|
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||||
|
channel = self._authed_delete_device(
|
||||||
|
tok1, "dev2", session, "localuser", "localpass"
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
|
"@localuser:test", "localpass"
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(PasswordOnlyAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_password_auth_disabled(self):
|
||||||
|
"""password auth doesn't work if it's disabled across the board"""
|
||||||
|
# login flows should be empty
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
|
channel = self._send_password_login("u", "p")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_login(self):
|
||||||
|
# login flows should have the custom flow and m.login.password, since we
|
||||||
|
# haven't disabled local password lookup.
|
||||||
|
# (password must come first, because reasons)
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(
|
||||||
|
flows,
|
||||||
|
[{"type": "m.login.password"}, {"type": "test.login_type"}]
|
||||||
|
+ ADDITIONAL_LOGIN_FLOWS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# login with missing param should be rejected
|
||||||
|
channel = self._send_login("test.login_type", "u")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
||||||
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
"u", "test.login_type", {"test_field": "y"}
|
||||||
|
)
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# try a weird username. Again, it's unclear what we *expect* to happen
|
||||||
|
# in these cases, but at least we can guard against the API changing
|
||||||
|
# unexpectedly
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
"@ MALFORMED! :bz"
|
||||||
|
)
|
||||||
|
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_ui_auth(self):
|
||||||
|
# register the user and log in twice, to get two devices
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
tok1 = self.login("localuser", "localpass")
|
||||||
|
self.login("localuser", "localpass", device_id="dev2")
|
||||||
|
|
||||||
|
# make the initial request which returns a 401
|
||||||
|
channel = self._delete_device(tok1, "dev2")
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
# Ensure that flows are what is expected.
|
||||||
|
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||||
|
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# missing param
|
||||||
|
body = {
|
||||||
|
"auth": {
|
||||||
|
"type": "test.login_type",
|
||||||
|
"identifier": {"type": "m.id.user", "user": "localuser"},
|
||||||
|
"session": session,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
# there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
|
||||||
|
# use it...
|
||||||
|
self.assertIn("Missing parameters", channel.json_body["error"])
|
||||||
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# right params, but authing as the wrong user
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
||||||
|
body["auth"]["test_field"] = "foo"
|
||||||
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
self.assertEqual(channel.code, 403)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
"localuser", "test.login_type", {"test_field": "foo"}
|
||||||
|
)
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# and finally, succeed
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
"@localuser:test"
|
||||||
|
)
|
||||||
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
"localuser", "test.login_type", {"test_field": "foo"}
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(providers_config(CustomAuthProvider))
|
||||||
|
def test_custom_auth_provider_callback(self):
|
||||||
|
callback = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
("@user:bz", callback)
|
||||||
|
)
|
||||||
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
"u", "test.login_type", {"test_field": "y"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the args to the callback
|
||||||
|
callback.assert_called_once()
|
||||||
|
call_args, call_kwargs = callback.call_args
|
||||||
|
# should be one positional arg
|
||||||
|
self.assertEqual(len(call_args), 1)
|
||||||
|
self.assertEqual(call_args[0]["user_id"], "@user:bz")
|
||||||
|
for p in ["user_id", "access_token", "device_id", "home_server"]:
|
||||||
|
self.assertIn(p, call_args[0])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
||||||
|
)
|
||||||
|
def test_custom_auth_password_disabled(self):
|
||||||
|
"""Test login with a custom auth provider where password login is disabled"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(PasswordCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_password_custom_auth_password_disabled_login(self):
|
||||||
|
"""log in with a custom auth provider which implements password, but password
|
||||||
|
login is disabled"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(PasswordCustomAuthProvider),
|
||||||
|
"password_config": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_password_custom_auth_password_disabled_ui_auth(self):
|
||||||
|
"""UI Auth with a custom auth provider which implements password, but password
|
||||||
|
login is disabled"""
|
||||||
|
# register the user and log in twice via the test login type to get two devices,
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||||
|
"@localuser:test"
|
||||||
|
)
|
||||||
|
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
tok1 = channel.json_body["access_token"]
|
||||||
|
|
||||||
|
channel = self._send_login(
|
||||||
|
"test.login_type", "localuser", test_field="", device_id="dev2"
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
|
# make the initial request which returns a 401
|
||||||
|
channel = self._delete_device(tok1, "dev2")
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
# Ensure that flows are what is expected. In particular, "password" should *not*
|
||||||
|
# be present.
|
||||||
|
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# check that auth with password is rejected
|
||||||
|
body = {
|
||||||
|
"auth": {
|
||||||
|
"type": "m.login.password",
|
||||||
|
"identifier": {"type": "m.id.user", "user": "localuser"},
|
||||||
|
"password": "localpass",
|
||||||
|
"session": session,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
self.assertEqual(
|
||||||
|
"Password login has been disabled.", channel.json_body["error"]
|
||||||
|
)
|
||||||
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
|
# successful auth
|
||||||
|
body["auth"]["type"] = "test.login_type"
|
||||||
|
body["auth"]["test_field"] = "x"
|
||||||
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
|
"localuser", "test.login_type", {"test_field": "x"}
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
**providers_config(CustomAuthProvider),
|
||||||
|
"password_config": {"localdb_enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_custom_auth_no_local_user_fallback(self):
|
||||||
|
"""Test login with a custom auth provider where the local db is disabled"""
|
||||||
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
|
flows = self._get_login_flows()
|
||||||
|
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
|
# password login shouldn't work and should be rejected with a 400
|
||||||
|
# ("unknown login type")
|
||||||
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
|
||||||
|
def _get_login_flows(self) -> JsonDict:
|
||||||
|
_, channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
return channel.json_body["flows"]
|
||||||
|
|
||||||
|
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||||
|
return self._send_login(type="m.login.password", user=user, password=password)
|
||||||
|
|
||||||
|
def _send_login(self, type, user, **params) -> FakeChannel:
|
||||||
|
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
|
||||||
|
_, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
def _start_delete_device_session(self, access_token, device_id) -> str:
|
||||||
|
"""Make an initial delete device request, and return the UI Auth session ID"""
|
||||||
|
channel = self._delete_device(access_token, device_id)
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
# Ensure that flows are what is expected.
|
||||||
|
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||||
|
return channel.json_body["session"]
|
||||||
|
|
||||||
|
def _authed_delete_device(
|
||||||
|
self,
|
||||||
|
access_token: str,
|
||||||
|
device_id: str,
|
||||||
|
session: str,
|
||||||
|
user_id: str,
|
||||||
|
password: str,
|
||||||
|
) -> FakeChannel:
|
||||||
|
"""Make a delete device request, authenticating with the given uid/password"""
|
||||||
|
return self._delete_device(
|
||||||
|
access_token,
|
||||||
|
device_id,
|
||||||
|
{
|
||||||
|
"auth": {
|
||||||
|
"type": "m.login.password",
|
||||||
|
"identifier": {"type": "m.id.user", "user": user_id},
|
||||||
|
"password": password,
|
||||||
|
"session": session,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_device(
|
||||||
|
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
|
||||||
|
) -> FakeChannel:
|
||||||
|
"""Delete an individual device."""
|
||||||
|
_, channel = self.make_request(
|
||||||
|
"DELETE", "devices/" + device, body, access_token=access_token
|
||||||
|
)
|
||||||
|
return channel
|
168
tests/handlers/test_saml.py
Normal file
168
tests/handlers/test_saml.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from synapse.handlers.sso import MappingException
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
# These are a few constants that are used as config parameters in the tests.
|
||||||
|
BASE_URL = "https://synapse/"
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class FakeAuthnResponse:
|
||||||
|
ava = attr.ib(type=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMappingProvider:
|
||||||
|
def __init__(self, config, module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config):
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_saml_attributes(config):
|
||||||
|
return {"uid"}, {"displayName"}
|
||||||
|
|
||||||
|
def get_remote_user_id(self, saml_response, client_redirect_url):
|
||||||
|
return saml_response.ava["uid"]
|
||||||
|
|
||||||
|
def saml_response_to_user_attributes(
|
||||||
|
self, saml_response, failures, client_redirect_url
|
||||||
|
):
|
||||||
|
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
||||||
|
return {"mxid_localpart": localpart, "displayname": None}
|
||||||
|
|
||||||
|
|
||||||
|
class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
config["public_baseurl"] = BASE_URL
|
||||||
|
saml_config = {
|
||||||
|
"sp_config": {"metadata": {}},
|
||||||
|
# Disable grandfathering.
|
||||||
|
"grandfathered_mxid_source_attribute": None,
|
||||||
|
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update this config with what's in the default config so that
|
||||||
|
# override_config works as expected.
|
||||||
|
saml_config.update(config.get("saml2_config", {}))
|
||||||
|
config["saml2_config"] = saml_config
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
|
self.handler = hs.get_saml_handler()
|
||||||
|
|
||||||
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
|
sso_handler = hs.get_sso_handler()
|
||||||
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def test_map_saml_response_to_user(self):
|
||||||
|
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
# The redirect_url doesn't matter with the default user mapping provider.
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
|
def test_map_saml_response_to_existing_user(self):
|
||||||
|
"""Existing users can log in with SAML account."""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map a user via SSO.
|
||||||
|
saml_response = FakeAuthnResponse(
|
||||||
|
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
|
||||||
|
)
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
# Subsequent calls should map to the same mxid.
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
def test_map_saml_response_to_invalid_localpart(self):
|
||||||
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
||||||
|
redirect_url = ""
|
||||||
|
e = self.get_failure(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
),
|
||||||
|
MappingException,
|
||||||
|
)
|
||||||
|
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||||
|
|
||||||
|
def test_map_saml_response_to_user_retries(self):
|
||||||
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||||
|
)
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
|
self.assertEqual(mxid, "@test_user1:test")
|
||||||
|
|
||||||
|
# Register all of the potential mxids for a particular SAML username.
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester:test", password_hash=None)
|
||||||
|
)
|
||||||
|
for i in range(1, 3):
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attempt to map to a username, this will fail since all potential usernames are taken.
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
|
||||||
|
e = self.get_failure(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
),
|
||||||
|
MappingException,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||||
|
)
|
|
@ -12,7 +12,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, room
|
||||||
|
from synapse.rest.client.v2_alpha import receipts
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
|
||||||
class HTTPPusherTests(HomeserverTestCase):
|
class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
receipts.register_servlets,
|
||||||
]
|
]
|
||||||
user_id = True
|
user_id = True
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
|
||||||
# check that this is low-priority
|
# check that this is low-priority
|
||||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
||||||
|
|
||||||
|
def test_push_unread_count_group_by_room(self):
|
||||||
|
"""
|
||||||
|
The HTTP pusher will group unread count by number of unread rooms.
|
||||||
|
"""
|
||||||
|
# Carry out common push count tests and setup
|
||||||
|
self._test_push_unread_count()
|
||||||
|
|
||||||
|
# Carry out our option-value specific test
|
||||||
|
#
|
||||||
|
# This push should still only contain an unread count of 1 (for 1 unread room)
|
||||||
|
self.assertEqual(
|
||||||
|
self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config({"push": {"group_unread_count_by_room": False}})
|
||||||
|
def test_push_unread_count_message_count(self):
|
||||||
|
"""
|
||||||
|
The HTTP pusher will send the total unread message count.
|
||||||
|
"""
|
||||||
|
# Carry out common push count tests and setup
|
||||||
|
self._test_push_unread_count()
|
||||||
|
|
||||||
|
# Carry out our option-value specific test
|
||||||
|
#
|
||||||
|
# We're counting every unread message, so there should now be 4 since the
|
||||||
|
# last read receipt
|
||||||
|
self.assertEqual(
|
||||||
|
self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_push_unread_count(self):
|
||||||
|
"""
|
||||||
|
Tests that the correct unread count appears in sent push notifications
|
||||||
|
|
||||||
|
Note that:
|
||||||
|
* Sending messages will cause push notifications to go out to relevant users
|
||||||
|
* Sending a read receipt will cause a "badge update" notification to go out to
|
||||||
|
the user that sent the receipt
|
||||||
|
"""
|
||||||
|
# Register the user who gets notified
|
||||||
|
user_id = self.register_user("user", "pass")
|
||||||
|
access_token = self.login("user", "pass")
|
||||||
|
|
||||||
|
# Register the user who sends the message
|
||||||
|
other_user_id = self.register_user("other_user", "pass")
|
||||||
|
other_access_token = self.login("other_user", "pass")
|
||||||
|
|
||||||
|
# Create a room (as other_user)
|
||||||
|
room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
|
||||||
|
|
||||||
|
# The user to get notified joins
|
||||||
|
self.helper.join(room=room_id, user=user_id, tok=access_token)
|
||||||
|
|
||||||
|
# Register the pusher
|
||||||
|
user_tuple = self.get_success(
|
||||||
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
|
)
|
||||||
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
user_id=user_id,
|
||||||
|
access_token=token_id,
|
||||||
|
kind="http",
|
||||||
|
app_id="m.http",
|
||||||
|
app_display_name="HTTP Push Notifications",
|
||||||
|
device_display_name="pushy push",
|
||||||
|
pushkey="a@example.com",
|
||||||
|
lang=None,
|
||||||
|
data={"url": "example.com"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send a message
|
||||||
|
response = self.helper.send(
|
||||||
|
room_id, body="Hello there!", tok=other_access_token
|
||||||
|
)
|
||||||
|
# To get an unread count, the user who is getting notified has to have a read
|
||||||
|
# position in the room. We'll set the read position to this event in a moment
|
||||||
|
first_message_event_id = response["event_id"]
|
||||||
|
|
||||||
|
# Advance time a bit (so the pusher will register something has happened) and
|
||||||
|
# make the push succeed
|
||||||
|
self.push_attempts[0][0].callback({})
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Check our push made it
|
||||||
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
|
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||||
|
|
||||||
|
# Check that the unread count for the room is 0
|
||||||
|
#
|
||||||
|
# The unread count is zero as the user has no read receipt in the room yet
|
||||||
|
self.assertEqual(
|
||||||
|
self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now set the user's read receipt position to the first event
|
||||||
|
#
|
||||||
|
# This will actually trigger a new notification to be sent out so that
|
||||||
|
# even if the user does not receive another message, their unread
|
||||||
|
# count goes down
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
|
||||||
|
{},
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
|
# Advance time and make the push succeed
|
||||||
|
self.push_attempts[1][0].callback({})
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Unread count is still zero as we've read the only message in the room
|
||||||
|
self.assertEqual(len(self.push_attempts), 2)
|
||||||
|
self.assertEqual(
|
||||||
|
self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send another message
|
||||||
|
self.helper.send(
|
||||||
|
room_id, body="How's the weather today?", tok=other_access_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance time and make the push succeed
|
||||||
|
self.push_attempts[2][0].callback({})
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# This push should contain an unread count of 1 as there's now been one
|
||||||
|
# message since our last read receipt
|
||||||
|
self.assertEqual(len(self.push_attempts), 3)
|
||||||
|
self.assertEqual(
|
||||||
|
self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since we're grouping by room, sending more messages shouldn't increase the
|
||||||
|
# unread count, as they're all being sent in the same room
|
||||||
|
self.helper.send(room_id, body="Hello?", tok=other_access_token)
|
||||||
|
|
||||||
|
# Advance time and make the push succeed
|
||||||
|
self.pump()
|
||||||
|
self.push_attempts[3][0].callback({})
|
||||||
|
|
||||||
|
self.helper.send(room_id, body="Hello??", tok=other_access_token)
|
||||||
|
|
||||||
|
# Advance time and make the push succeed
|
||||||
|
self.pump()
|
||||||
|
self.push_attempts[4][0].callback({})
|
||||||
|
|
||||||
|
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
|
||||||
|
|
||||||
|
# Advance time and make the push succeed
|
||||||
|
self.pump()
|
||||||
|
self.push_attempts[5][0].callback({})
|
||||||
|
|
||||||
|
self.assertEqual(len(self.push_attempts), 6)
|
||||||
|
|
|
@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
|
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
|
||||||
|
|
||||||
self.test_handler = self._build_replication_data_handler()
|
self.test_handler = self._build_replication_data_handler()
|
||||||
self.worker_hs.replication_data_handler = self.test_handler
|
self.worker_hs._replication_data_handler = self.test_handler
|
||||||
|
|
||||||
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
||||||
self.client = ClientReplicationStreamProtocol(
|
self.client = ClientReplicationStreamProtocol(
|
||||||
|
|
|
@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler = hs.get_device_handler()
|
self.handler = hs.get_device_handler()
|
||||||
self.media_repo = hs.get_media_repository_resource()
|
self.media_repo = hs.get_media_repository_resource()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.clock
|
|
||||||
|
|
||||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
self.admin_user_tok = self.login("admin", "pass")
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
|
|
@ -33,12 +33,15 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
presence_handler = Mock()
|
||||||
"red", http_client=None, federation_client=Mock()
|
presence_handler.set_state.return_value = defer.succeed(None)
|
||||||
)
|
|
||||||
|
|
||||||
hs.presence_handler = Mock()
|
hs = self.setup_test_homeserver(
|
||||||
hs.presence_handler.set_state.return_value = defer.succeed(None)
|
"red",
|
||||||
|
http_client=None,
|
||||||
|
federation_client=Mock(),
|
||||||
|
presence_handler=presence_handler,
|
||||||
|
)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
|
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
|
||||||
|
|
||||||
def test_put_presence_disabled(self):
|
def test_put_presence_disabled(self):
|
||||||
"""
|
"""
|
||||||
|
@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
|
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
|
||||||
|
|
|
@ -41,14 +41,37 @@ class RestHelper:
|
||||||
auth_user_id = attr.ib()
|
auth_user_id = attr.ib()
|
||||||
|
|
||||||
def create_room_as(
|
def create_room_as(
|
||||||
self, room_creator=None, is_public=True, tok=None, expect_code=200,
|
self,
|
||||||
):
|
room_creator: str = None,
|
||||||
|
is_public: bool = True,
|
||||||
|
room_version: str = None,
|
||||||
|
tok: str = None,
|
||||||
|
expect_code: int = 200,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_creator: The user ID to create the room with.
|
||||||
|
is_public: If True, the `visibility` parameter will be set to the
|
||||||
|
default (public). Otherwise, the `visibility` parameter will be set
|
||||||
|
to "private".
|
||||||
|
room_version: The room version to create the room as. Defaults to Synapse's
|
||||||
|
default room version.
|
||||||
|
tok: The access token to use in the request.
|
||||||
|
expect_code: The expected HTTP response code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ID of the newly created room.
|
||||||
|
"""
|
||||||
temp_id = self.auth_user_id
|
temp_id = self.auth_user_id
|
||||||
self.auth_user_id = room_creator
|
self.auth_user_id = room_creator
|
||||||
path = "/_matrix/client/r0/createRoom"
|
path = "/_matrix/client/r0/createRoom"
|
||||||
content = {}
|
content = {}
|
||||||
if not is_public:
|
if not is_public:
|
||||||
content["visibility"] = "private"
|
content["visibility"] = "private"
|
||||||
|
if room_version:
|
||||||
|
content["room_version"] = room_version
|
||||||
if tok:
|
if tok:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
|
|
|
@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
|
||||||
return succeed(True)
|
return succeed(True)
|
||||||
|
|
||||||
|
|
||||||
class DummyPasswordChecker(UserInteractiveAuthChecker):
|
|
||||||
def check_auth(self, authdict, clientip):
|
|
||||||
return succeed(authdict["identifier"]["user"])
|
|
||||||
|
|
||||||
|
|
||||||
class FallbackAuthTests(unittest.HomeserverTestCase):
|
class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
|
@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
auth_handler = hs.get_auth_handler()
|
|
||||||
auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
|
|
||||||
|
|
||||||
self.user_pass = "pass"
|
self.user_pass = "pass"
|
||||||
self.user = self.register_user("test", self.user_pass)
|
self.user = self.register_user("test", self.user_pass)
|
||||||
self.user_tok = self.login("test", self.user_pass)
|
self.user_tok = self.login("test", self.user_pass)
|
||||||
|
@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_grandfathered_identifier(self):
|
||||||
|
"""Check behaviour without "identifier" dict
|
||||||
|
|
||||||
|
Synapse used to require clients to submit a "user" field for m.login.password
|
||||||
|
UIA - check that still works.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_id = self.get_device_ids()[0]
|
||||||
|
channel = self.delete_device(device_id, 401)
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
# Make another request providing the UI auth flow.
|
||||||
|
self.delete_device(
|
||||||
|
device_id,
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"auth": {
|
||||||
|
"type": "m.login.password",
|
||||||
|
"user": self.user,
|
||||||
|
"password": self.user_pass,
|
||||||
|
"session": session,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_can_change_body(self):
|
def test_can_change_body(self):
|
||||||
"""
|
"""
|
||||||
The client dict can be modified during the user interactive authentication session.
|
The client dict can be modified during the user interactive authentication session.
|
||||||
|
|
|
@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
# We need to manually add an email address otherwise the handler will do
|
# We need to manually add an email address otherwise the handler will do
|
||||||
# nothing.
|
# nothing.
|
||||||
now = self.hs.clock.time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.user_add_threepid(
|
self.store.user_add_threepid(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# We need to manually add an email address otherwise the handler will do
|
# We need to manually add an email address otherwise the handler will do
|
||||||
# nothing.
|
# nothing.
|
||||||
now = self.hs.clock.time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.user_add_threepid(
|
self.store.user_add_threepid(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
|
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
|
||||||
|
|
||||||
now_ms = self.hs.clock.time_msec()
|
now_ms = self.hs.get_clock().time_msec()
|
||||||
self.get_success(self.store._set_expiration_date_when_missing())
|
self.get_success(self.store._set_expiration_date_when_missing())
|
||||||
|
|
||||||
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
|
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
|
||||||
|
|
|
@ -271,7 +271,7 @@ def setup_test_homeserver(
|
||||||
|
|
||||||
# Install @cache_in_self attributes
|
# Install @cache_in_self attributes
|
||||||
for key, val in kwargs.items():
|
for key, val in kwargs.items():
|
||||||
setattr(hs, key, val)
|
setattr(hs, "_" + key, val)
|
||||||
|
|
||||||
# Mock TLS
|
# Mock TLS
|
||||||
hs.tls_server_context_factory = Mock()
|
hs.tls_server_context_factory = Mock()
|
||||||
|
|
Loading…
Reference in a new issue