Merge remote-tracking branch 'origin/develop' into erikj/type_server

This commit is contained in:
Erik Johnston 2020-08-11 22:03:14 +01:00
commit fdb46b5442
19 changed files with 406 additions and 117 deletions

1
changelog.d/8040.misc Normal file
View file

@ -0,0 +1 @@
Change the default log config to reduce disk I/O and storage for new servers.

1
changelog.d/8050.misc Normal file
View file

@ -0,0 +1 @@
Reduce amount of outbound request logging at INFO level.

1
changelog.d/8051.misc Normal file
View file

@ -0,0 +1 @@
It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.)

1
changelog.d/8052.feature Normal file
View file

@ -0,0 +1 @@
Allow login to be blocked based on the values of SAML attributes.

1
changelog.d/8058.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `Notifier`.

View file

@ -4,16 +4,10 @@ formatters:
precise: precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers: handlers:
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context]
loggers: loggers:
synapse.storage.SQL: synapse.storage.SQL:

View file

@ -1577,6 +1577,17 @@ saml2_config:
# #
#grandfathered_mxid_source_attribute: upn #grandfathered_mxid_source_attribute: upn
# It is possible to configure Synapse to only allow logins if SAML attributes
# match particular values. The requirements can be listed under
# `attribute_requirements` as shown below. All of the listed attributes must
# match for the login to be permitted.
#
#attribute_requirements:
# - attribute: userGroup
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below. # Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used. # If not set, default templates from within the Synapse package will be used.
# #

View file

@ -11,24 +11,33 @@ formatters:
precise: precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers: handlers:
file: file:
class: logging.handlers.RotatingFileHandler class: logging.handlers.TimedRotatingFileHandler
formatter: precise formatter: precise
filename: /var/log/matrix-synapse/homeserver.log filename: /var/log/matrix-synapse/homeserver.log
maxBytes: 104857600 when: midnight
backupCount: 10 backupCount: 3 # Does not include the current log file.
filters: [context]
encoding: utf8 encoding: utf8
# Default to buffering writes to log file for efficiency. This means that
# will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
# logs will still be flushed immediately.
buffer:
class: logging.handlers.MemoryHandler
target: file
# The capacity is the number of log lines that are buffered before
# being written to disk. Increasing this will lead to better
# performance, at the expensive of it taking longer for log lines to
# be written to disk.
capacity: 10
flushLevel: 30 # Flush for WARNING logs as well
# A handler that writes logs to stderr. Unused by default, but can be used
# instead of "buffer" and "file" in the logger handlers.
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context]
loggers: loggers:
synapse.storage.SQL: synapse.storage.SQL:
@ -36,8 +45,23 @@ loggers:
# information such as access tokens. # information such as access tokens.
level: INFO level: INFO
twisted:
# We send the twisted logging directly to the file handler,
# to work around https://github.com/matrix-org/synapse/issues/3471
# when using "buffer" logger. Use "console" to log to stderr instead.
handlers: [file]
propagate: false
root: root:
level: INFO level: INFO
handlers: [file, console]
# Write logs to the `buffer` handler, which will buffer them together in memory,
# then write them to a file.
#
# Replace "buffer" with "console" to log to stderr instead. (Note that you'll
# also need to update the configuation for the `twisted` logger above, in
# this case.)
#
handlers: [buffer]
disable_existing_loggers: false disable_existing_loggers: false

49
synapse/config/_util.py Normal file
View file

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
import jsonschema
from synapse.config._base import ConfigError
from synapse.types import JsonDict
def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None:
"""Validates a config setting against a JsonSchema definition
This can be used to validate a section of the config file against a schema
definition. If the validation fails, a ConfigError is raised with a textual
description of the problem.
Args:
json_schema: the schema to validate against
config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis
for the error message.
"""
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
raise ConfigError(
"Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
)

View file

@ -55,24 +55,33 @@ formatters:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
%(request)s - %(message)s' %(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers: handlers:
file: file:
class: logging.handlers.RotatingFileHandler class: logging.handlers.TimedRotatingFileHandler
formatter: precise formatter: precise
filename: ${log_file} filename: ${log_file}
maxBytes: 104857600 when: midnight
backupCount: 10 backupCount: 3 # Does not include the current log file.
filters: [context]
encoding: utf8 encoding: utf8
# Default to buffering writes to log file for efficiency. This means that
# will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
# logs will still be flushed immediately.
buffer:
class: logging.handlers.MemoryHandler
target: file
# The capacity is the number of log lines that are buffered before
# being written to disk. Increasing this will lead to better
# performance, at the expensive of it taking longer for log lines to
# be written to disk.
capacity: 10
flushLevel: 30 # Flush for WARNING logs as well
# A handler that writes logs to stderr. Unused by default, but can be used
# instead of "buffer" and "file" in the logger handlers.
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context]
loggers: loggers:
synapse.storage.SQL: synapse.storage.SQL:
@ -80,9 +89,24 @@ loggers:
# information such as access tokens. # information such as access tokens.
level: INFO level: INFO
twisted:
# We send the twisted logging directly to the file handler,
# to work around https://github.com/matrix-org/synapse/issues/3471
# when using "buffer" logger. Use "console" to log to stderr instead.
handlers: [file]
propagate: false
root: root:
level: INFO level: INFO
handlers: [file, console]
# Write logs to the `buffer` handler, which will buffer them together in memory,
# then write them to a file.
#
# Replace "buffer" with "console" to log to stderr instead. (Note that you'll
# also need to update the configuation for the `twisted` logger above, in
# this case.)
#
handlers: [buffer]
disable_existing_loggers: false disable_existing_loggers: false
""" """
@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler) logger.addHandler(handler)
else: else:
logging.config.dictConfig(log_config) logging.config.dictConfig(log_config)
# We add a log record factory that runs all messages through the
# LoggingContextFilter so that we get the context *at the time we log*
# rather than when we write to a handler. This can be done in config using
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
log_filter = LoggingContextFilter(request="")
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
record = old_factory(*args, **kwargs)
log_filter.filter(record)
return record
logging.setLogRecordFactory(factory)
# Route Twisted's native logging through to the standard library logging # Route Twisted's native logging through to the standard library logging
# system. # system.
observer = STDLibLogObserver() observer = STDLibLogObserver()

View file

@ -15,7 +15,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, List
import attr
import jinja2 import jinja2
import pkg_resources import pkg_resources
@ -23,6 +25,7 @@ from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError from ._base import Config, ConfigError
from ._util import validate_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -80,6 +83,11 @@ class SAML2Config(Config):
self.saml2_enabled = True self.saml2_enabled = True
attribute_requirements = saml2_config.get("attribute_requirements") or []
self.attribute_requirements = _parse_attribute_requirements_def(
attribute_requirements
)
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid" "grandfathered_mxid_source_attribute", "uid"
) )
@ -341,6 +349,17 @@ class SAML2Config(Config):
# #
#grandfathered_mxid_source_attribute: upn #grandfathered_mxid_source_attribute: upn
# It is possible to configure Synapse to only allow logins if SAML attributes
# match particular values. The requirements can be listed under
# `attribute_requirements` as shown below. All of the listed attributes must
# match for the login to be permitted.
#
#attribute_requirements:
# - attribute: userGroup
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below. # Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used. # If not set, default templates from within the Synapse package will be used.
# #
@ -368,3 +387,34 @@ class SAML2Config(Config):
""" % { """ % {
"config_dir_path": config_dir_path "config_dir_path": config_dir_path
} }
@attr.s(frozen=True)
class SamlAttributeRequirement:
"""Object describing a single requirement for SAML attributes."""
attribute = attr.ib(type=str)
value = attr.ib(type=str)
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
"items": SamlAttributeRequirement.JSON_SCHEMA,
}
def _parse_attribute_requirements_def(
attribute_requirements: Any,
) -> List[SamlAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
config_path=["saml2_config", "attribute_requirements"],
)
return [SamlAttributeRequirement(**x) for x in attribute_requirements]

View file

@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
timeout=0, timeout=0,
as_client_event=True, as_client_event=True,
affect_presence=True, affect_presence=True,
only_keys=None,
room_id=None, room_id=None,
is_guest=False, is_guest=False,
): ):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
""" """
if room_id: if room_id:
@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
auth_user, auth_user,
pagin_config, pagin_config,
timeout, timeout,
only_keys=only_keys,
is_guest=is_guest, is_guest=is_guest,
explicit_room_id=room_id, explicit_room_id=room_id,
) )

View file

@ -14,15 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import Callable, Dict, Optional, Set, Tuple from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
import attr import attr
import saml2 import saml2
import saml2.response import saml2.response
from saml2.client import Saml2Client from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
@ -34,6 +35,9 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,7 +53,7 @@ class Saml2SessionData:
class SamlHandler: class SamlHandler:
def __init__(self, hs): def __init__(self, hs: "synapse.server.HomeServer"):
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -62,6 +66,7 @@ class SamlHandler:
self._grandfathered_mxid_source_attribute = ( self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute hs.config.saml2_grandfathered_mxid_source_attribute
) )
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
# plugin to do custom mapping from saml response to mxid # plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@ -73,7 +78,7 @@ class SamlHandler:
self._auth_provider_id = "saml" self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings # a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
@ -165,11 +170,18 @@ class SamlHandler:
saml2.BINDING_HTTP_POST, saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict, outstanding=self._outstanding_requests_dict,
) )
except saml2.response.UnsolicitedResponse as e:
# the pysaml2 library helpfully logs an ERROR here, but neglects to log
# the session ID. I don't really want to put the full text of the exception
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
except Exception as e: except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
if saml2_auth.not_signed: if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed.")
logger.debug("SAML2 response: %s", saml2_auth.origxml) logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions: for assertion in saml2_auth.assertions:
@ -188,6 +200,9 @@ class SamlHandler:
saml2_auth.in_response_to, None saml2_auth.in_response_to, None
) )
for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)
remote_user_id = self._user_mapping_provider.get_remote_user_id( remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url saml2_auth, client_redirect_url
) )
@ -294,6 +309,21 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid] del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
DOT_REPLACE_PATTERN = re.compile( DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
) )

View file

@ -297,7 +297,7 @@ class SimpleHttpClient(object):
outgoing_requests_counter.labels(method).inc() outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this) # log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri)) logger.debug("Sending request %s %s", method, redact_uri(uri))
with start_active_span( with start_active_span(
"outgoing-client-request", "outgoing-client-request",

View file

@ -247,7 +247,7 @@ class MatrixHostnameEndpoint(object):
port = server.port port = server.port
try: try:
logger.info("Connecting to %s:%i", host.decode("ascii"), port) logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
endpoint = HostnameEndpoint(self._reactor, host, port) endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options: if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint) endpoint = wrapClientTLS(self._tls_options, endpoint)

View file

@ -29,10 +29,11 @@ from zope.interface import implementer
from twisted.internet import defer, protocol from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
import synapse.metrics import synapse.metrics
import synapse.util.retryutils import synapse.util.retryutils
@ -74,7 +75,7 @@ MAXINT = sys.maxsize
_next_id = 1 _next_id = 1
@attr.s @attr.s(frozen=True)
class MatrixFederationRequest(object): class MatrixFederationRequest(object):
method = attr.ib() method = attr.ib()
"""HTTP method """HTTP method
@ -110,26 +111,52 @@ class MatrixFederationRequest(object):
:type: str|None :type: str|None
""" """
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
def __attrs_post_init__(self): def __attrs_post_init__(self):
global _next_id global _next_id
self.txn_id = "%s-O-%s" % (self.method, _next_id) txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1) _next_id = (_next_id + 1) % (MAXINT - 1)
object.__setattr__(self, "txn_id", txn_id)
destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
if self.query:
query_bytes = encode_query_args(self.query)
else:
query_bytes = b""
# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
(b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
)
object.__setattr__(self, "uri", uri)
def get_json(self): def get_json(self):
if self.json_callback: if self.json_callback:
return self.json_callback() return self.json_callback()
return self.json return self.json
async def _handle_json_response(reactor, timeout_sec, request, response): async def _handle_json_response(
reactor: IReactorTime,
timeout_sec: float,
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
):
""" """
Reads the JSON body of a response, with a timeout Reads the JSON body of a response, with a timeout
Args: Args:
reactor (IReactor): twisted reactor, for the timeout reactor: twisted reactor, for the timeout
timeout_sec (float): number of seconds to wait for response to complete timeout_sec: number of seconds to wait for response to complete
request (MatrixFederationRequest): the request that triggered the response request: the request that triggered the response
response (IResponse): response to the request response: response to the request
start_ms: Timestamp when request was made
Returns: Returns:
dict: parsed JSON response dict: parsed JSON response
@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response):
body = await make_deferred_yieldable(d) body = await make_deferred_yieldable(d)
except TimeoutError as e: except TimeoutError as e:
logger.warning( logger.warning(
"{%s} [%s] Timed out reading response", request.txn_id, request.destination, "{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
) )
raise RequestSendFailed(e, can_retry=True) from e raise RequestSendFailed(e, can_retry=True) from e
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"{%s} [%s] Error reading response: %s", "{%s} [%s] Error reading response %s %s: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
request.method,
request.uri.decode("ascii"),
e, e,
) )
raise raise
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info( logger.info(
"{%s} [%s] Completed: %d %s", "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
response.code, response.code,
response.phrase.decode("ascii", errors="replace"), response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
request.method,
request.uri.decode("ascii"),
) )
return body return body
@ -261,7 +300,9 @@ class MatrixFederationHttpClient(object):
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a # 'M_UNRECOGNIZED' which some endpoints can return when omitting a
# trailing slash on Synapse <= v0.99.3. # trailing slash on Synapse <= v0.99.3.
logger.info("Retrying request with trailing slash") logger.info("Retrying request with trailing slash")
request.path += "/"
# Request is frozen so we create a new instance
request = attr.evolve(request, path=request.path + "/")
response = await self._send_request(request, **send_request_args) response = await self._send_request(request, **send_request_args)
@ -373,9 +414,7 @@ class MatrixFederationHttpClient(object):
else: else:
retries_left = MAX_SHORT_RETRIES retries_left = MAX_SHORT_RETRIES
url_bytes = urllib.parse.urlunparse( url_bytes = request.uri
(b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
)
url_str = url_bytes.decode("ascii") url_str = url_bytes.decode("ascii")
url_to_sign_bytes = urllib.parse.urlunparse( url_to_sign_bytes = urllib.parse.urlunparse(
@ -402,7 +441,7 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers headers_dict[b"Authorization"] = auth_headers
logger.info( logger.debug(
"{%s} [%s] Sending request: %s %s; timeout %fs", "{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -436,7 +475,6 @@ class MatrixFederationHttpClient(object):
except DNSLookupError as e: except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e: except Exception as e:
logger.info("Failed to send request: %s", e)
raise RequestSendFailed(e, can_retry=True) from e raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels( incoming_responses_counter.labels(
@ -496,7 +534,7 @@ class MatrixFederationHttpClient(object):
break break
except RequestSendFailed as e: except RequestSendFailed as e:
logger.warning( logger.info(
"{%s} [%s] Request failed: %s %s: %s", "{%s} [%s] Request failed: %s %s: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -654,6 +692,8 @@ class MatrixFederationHttpClient(object):
json=data, json=data,
) )
start_ms = self.clock.time_msec()
response = await self._send_request_with_optional_trailing_slash( response = await self._send_request_with_optional_trailing_slash(
request, request,
try_trailing_slash_on_400, try_trailing_slash_on_400,
@ -664,7 +704,7 @@ class MatrixFederationHttpClient(object):
) )
body = await _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response, start_ms
) )
return body return body
@ -720,6 +760,8 @@ class MatrixFederationHttpClient(object):
method="POST", destination=destination, path=path, query=args, json=data method="POST", destination=destination, path=path, query=args, json=data
) )
start_ms = self.clock.time_msec()
response = await self._send_request( response = await self._send_request(
request, request,
long_retries=long_retries, long_retries=long_retries,
@ -733,7 +775,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( body = await _handle_json_response(
self.reactor, _sec_timeout, request, response self.reactor, _sec_timeout, request, response, start_ms,
) )
return body return body
@ -786,6 +828,8 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args method="GET", destination=destination, path=path, query=args
) )
start_ms = self.clock.time_msec()
response = await self._send_request_with_optional_trailing_slash( response = await self._send_request_with_optional_trailing_slash(
request, request,
try_trailing_slash_on_400, try_trailing_slash_on_400,
@ -796,7 +840,7 @@ class MatrixFederationHttpClient(object):
) )
body = await _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response, start_ms
) )
return body return body
@ -846,6 +890,8 @@ class MatrixFederationHttpClient(object):
method="DELETE", destination=destination, path=path, query=args method="DELETE", destination=destination, path=path, query=args
) )
start_ms = self.clock.time_msec()
response = await self._send_request( response = await self._send_request(
request, request,
long_retries=long_retries, long_retries=long_retries,
@ -854,7 +900,7 @@ class MatrixFederationHttpClient(object):
) )
body = await _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response, start_ms
) )
return body return body
@ -914,12 +960,14 @@ class MatrixFederationHttpClient(object):
) )
raise raise
logger.info( logger.info(
"{%s} [%s] Completed: %d %s [%d bytes]", "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
response.code, response.code,
response.phrase.decode("ascii", errors="replace"), response.phrase.decode("ascii", errors="replace"),
length, length,
request.method,
request.uri.decode("ascii"),
) )
return (length, headers) return (length, headers)

View file

@ -15,7 +15,17 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Callable, Iterable, List, TypeVar from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from prometheus_client import Counter from prometheus_client import Counter
@ -24,12 +34,14 @@ from twisted.internet import defer
import synapse.server import synapse.server
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import PreserveLoggingContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -77,7 +89,13 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms): def __init__(
self,
user_id: str,
rooms: Collection[str],
current_token: StreamToken,
time_now_ms: int,
):
self.user_id = user_id self.user_id = user_id
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
@ -93,13 +111,13 @@ class _NotifierUserStream(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
Args: Args:
stream_key(str): The stream the event came from. stream_key: The stream the event came from.
stream_id(str): The new id for the stream the event came from. stream_id: The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds. time_now_ms: The current time in milliseconds.
""" """
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token self.last_notified_token = self.current_token
@ -112,7 +130,7 @@ class _NotifierUserStream(object):
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) noify_deferred.callback(self.current_token)
def remove(self, notifier): def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
it knows about. it knows about.
""" """
@ -123,10 +141,10 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
def count_listeners(self): def count_listeners(self) -> int:
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
def new_listener(self, token): def new_listener(self, token: StreamToken) -> _NotificationListener:
"""Returns a deferred that is resolved when there is a new token """Returns a deferred that is resolved when there is a new token
greater than the given token. greater than the given token.
@ -159,14 +177,16 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {} self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
self.room_to_user_streams = {} self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]
self.hs = hs self.hs = hs
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] self.pending_new_room_events = (
[]
) # type: List[Tuple[int, EventBase, Collection[str]]]
# Called when there are new things to stream over replication # Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]] self.replication_callbacks = [] # type: List[Callable[[], None]]
@ -178,10 +198,9 @@ class Notifier(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.federation_sender = None
if hs.should_send_federation(): if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
else:
self.federation_sender = None
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -193,12 +212,12 @@ class Notifier(object):
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
def count_listeners(): def count_listeners():
all_user_streams = set() all_user_streams = set() # type: Set[_NotifierUserStream]
for x in list(self.room_to_user_streams.values()): for streams in list(self.room_to_user_streams.values()):
all_user_streams |= x all_user_streams |= streams
for x in list(self.user_to_user_stream.values()): for stream in list(self.user_to_user_stream.values()):
all_user_streams.add(x) all_user_streams.add(stream)
return sum(stream.count_listeners() for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
@ -223,7 +242,11 @@ class Notifier(object):
self.replication_callbacks.append(cb) self.replication_callbacks.append(cb)
def on_new_room_event( def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[] self,
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
extra_users: Collection[str] = [],
): ):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
@ -241,11 +264,11 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id: int):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
Args: Args:
max_room_stream_id(int): The highest stream_id below which all max_room_stream_id: The highest stream_id below which all
events have been persisted. events have been persisted.
""" """
pending = self.pending_new_room_events pending = self.pending_new_room_events
@ -258,7 +281,9 @@ class Notifier(object):
else: else:
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
run_as_background_process( run_as_background_process(
@ -275,13 +300,19 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id] "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
) )
async def _notify_app_services(self, room_stream_id): async def _notify_app_services(self, room_stream_id: int):
try: try:
await self.appservice_handler.notify_interested_services(room_stream_id) await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(
self,
stream_key: str,
new_token: int,
users: Collection[str] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise. """ Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
@ -307,14 +338,19 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
def on_new_replication_data(self): def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
self.notify_replication() self.notify_replication()
async def wait_for_events( async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START self,
): user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids=None,
from_token=StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -377,19 +413,16 @@ class Notifier(object):
async def get_events_for( async def get_events_for(
self, self,
user, user: UserID,
pagination_config, pagination_config: PaginationConfig,
timeout, timeout: int,
only_keys=None, is_guest: bool = False,
is_guest=False, explicit_room_id: str = None,
explicit_room_id=None, ) -> EventStreamResult:
):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
If `only_keys` is not None, events from keys will be sent down.
If explicit_room_id is not set, the user's joined rooms will be polled If explicit_room_id is not set, the user's joined rooms will be polled
for events. for events.
If explicit_room_id is set, that room will be polled for events only if If explicit_room_id is set, that room will be polled for events only if
@ -404,11 +437,13 @@ class Notifier(object):
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined is_peeking = not is_joined
async def check_for_updates(before_token, after_token): async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token)) return EventStreamResult([], (from_token, from_token))
events = [] events = [] # type: List[EventBase]
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
@ -417,8 +452,6 @@ class Notifier(object):
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
if only_keys and name not in only_keys:
continue
new_events, new_key = await source.get_new_events( new_events, new_key = await source.get_new_events(
user=user, user=user,
@ -476,7 +509,9 @@ class Notifier(object):
return result return result
async def _get_room_ids(self, user, explicit_room_id): async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[Collection[str], bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id: if explicit_room_id:
if explicit_room_id in joined_room_ids: if explicit_room_id in joined_room_ids:
@ -486,7 +521,7 @@ class Notifier(object):
raise AuthError(403, "Non-joined access not allowed") raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True return joined_room_ids, True
async def _is_world_readable(self, room_id): async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state( state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
@ -496,7 +531,7 @@ class Notifier(object):
return False return False
@log_function @log_function
def remove_expired_streams(self): def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
expired_streams = [] expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
@ -510,21 +545,21 @@ class Notifier(object):
expired_stream.remove(self) expired_stream.remove(self)
@log_function @log_function
def _register_with_keys(self, user_stream): def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms: for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream) s.add(user_stream)
def _user_joined_room(self, user_id, room_id): def _user_joined_room(self, user_id: str, room_id: str):
new_user_stream = self.user_to_user_stream.get(user_id) new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None: if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id) new_user_stream.rooms.add(room_id)
def notify_replication(self): def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event""" """Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks: for cb in self.replication_callbacks:
cb() cb()

View file

@ -2,10 +2,17 @@
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<title>SSO error</title> <title>SSO login error</title>
</head> </head>
<body> <body>
<p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p> {# a 403 means we have actively rejected their login #}
{% if code == 403 %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
<p> <p>
If you are seeing this page after clicking a link sent to you via email, make If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the sure you only click the confirmation link once, and that you open the
@ -37,9 +44,9 @@
// to print one. // to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description") let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) { if (errorDesc) {
document.getElementById("errormsg").innerText = errorDesc;
document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
} }
</script> </script>
{% endif %}
</body> </body>
</html> </html>

View file

@ -198,6 +198,7 @@ commands = mypy \
synapse/logging/ \ synapse/logging/ \
synapse/metrics \ synapse/metrics \
synapse/module_api \ synapse/module_api \
synapse/notifier.py \
synapse/push/pusherpool.py \ synapse/push/pusherpool.py \
synapse/push/push_rule_evaluator.py \ synapse/push/push_rule_evaluator.py \
synapse/replication \ synapse/replication \