Mirror of https://github.com/element-hq/synapse
Commit 1ff3bc332a: Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
108 changed files with 2360 additions and 1540 deletions
@@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:

 ```
 # Install the dependencies
-pip install -e ".[lint]"
+pip install -e ".[lint,mypy]"

 # Run the linter script
 ./scripts-dev/lint.sh

@@ -63,7 +63,7 @@ run-time:

 ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
 ```

-You can also provided the `-d` option, which will lint the files that have been
+You can also provide the `-d` option, which will lint the files that have been
 changed since the last git commit. This will often be significantly faster than
 linting the whole codebase.
@@ -57,7 +57,7 @@ light workloads.

 System requirements:

 - POSIX-compliant system (tested on Linux & OS X)
-- Python 3.5.2 or later, up to Python 3.8.
+- Python 3.5.2 or later, up to Python 3.9.
 - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org

 Synapse is written in Python but some of the libraries it uses are written in
README.rst (10 lines changed)

@@ -256,9 +256,9 @@ directory of your choice::

 Synapse has a number of external dependencies, that are easiest
 to install using pip and a virtualenv::

-    virtualenv -p python3 env
-    source env/bin/activate
-    python -m pip install --no-use-pep517 -e ".[all]"
+    python3 -m venv ./env
+    source ./env/bin/activate
+    pip install -e ".[all,test]"

 This will run a process of downloading and installing all the needed
 dependencies into a virtual env.

@@ -270,9 +270,9 @@ check that everything is installed as it should be::

 This should end with a 'PASSED' result::

-    Ran 143 tests in 0.601s
+    Ran 1266 tests in 643.930s

-    PASSED (successes=143)
+    PASSED (skips=15, successes=1251)

 Running the Integration Tests
 =============================
UPGRADE.rst (16 lines added)

@@ -75,6 +75,22 @@ for example:

     wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb

+Upgrading to v1.23.0
+====================
+
+Structured logging configuration breaking changes
+--------------------------------------------------
+
+This release deprecates use of the ``structured: true`` logging configuration for
+structured logging. If your logging configuration contains ``structured: true``
+then it should be modified based on the `structured logging documentation
+<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_.
+
+The ``structured`` and ``drains`` logging options are now deprecated and should
+be replaced by standard logging configuration of ``handlers`` and ``formatters``.
+
+A future release of Synapse will make using ``structured: true`` an error.
+
 Upgrading to v1.22.0
 ====================
23 new files in changelog.d/ (one line each):

- changelog.d/8559.misc: Optimise `/createRoom` with multiple invited users.
- changelog.d/8595.misc: Implement and use an @lru_cache decorator.
- changelog.d/8607.feature: Support generating structured logs via the standard logging configuration.
- changelog.d/8610.feature: Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.
- changelog.d/8616.misc: Change schema to support access tokens belonging to one user but granting access to another.
- changelog.d/8633.misc: Run `mypy` as part of the lint.sh script.
- changelog.d/8655.misc: Add more type hints to the application services code.
- changelog.d/8664.misc: Tell Black to format code for Python 3.5.
- changelog.d/8665.doc: Note support for Python 3.9.
- changelog.d/8666.doc: Minor updates to docs on running tests.
- changelog.d/8667.doc: Interlink prometheus/grafana documentation.
- changelog.d/8669.misc: Don't pull event from DB when handling replication traffic.
- changelog.d/8671.misc: Abstract some invite-related code in preparation for landing knocking.
- changelog.d/8676.bugfix: Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0.
- changelog.d/8678.bugfix: Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.
- changelog.d/8679.misc: Clarify representation of events in logfiles.
- changelog.d/8680.misc: Don't require `hiredis` package to be installed to run unit tests.
- changelog.d/8682.bugfix: Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories.
- changelog.d/8684.misc: Fix typing info on cache call signature to accept `on_invalidate`.
- changelog.d/8685.feature: Support generating structured logs via the standard logging configuration.
- changelog.d/8688.misc: Abstract some invite-related code in preparation for landing knocking.
- changelog.d/8689.feature: Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.
- changelog.d/8690.misc: Fail tests if they do not await coroutines.
@@ -3,4 +3,4 @@

 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
 1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
-3. Set up additional recording rules
+3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus
@@ -611,3 +611,82 @@ The following parameters should be set in the URL:

 - ``user_id`` - fully qualified: for example, ``@user:server.com``.
 - ``device_id`` - The device to delete.

(the remainder of this hunk is newly added)

List all pushers
================
Gets information about all pushers for a specific ``user_id``.

The API is::

    GET /_synapse/admin/v1/users/<user_id>/pushers

To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.

A response body like the following is returned:

.. code:: json

    {
        "pushers": [
            {
                "app_display_name": "HTTP Push Notifications",
                "app_id": "m.http",
                "data": {
                    "url": "example.com"
                },
                "device_display_name": "pushy push",
                "kind": "http",
                "lang": "None",
                "profile_tag": "",
                "pushkey": "a@example.com"
            }
        ],
        "total": 1
    }

**Parameters**

The following parameters should be set in the URL:

- ``user_id`` - fully qualified: for example, ``@user:server.com``.

**Response**

The following fields are returned in the JSON response body:

- ``pushers`` - An array containing the current pushers for the user

  - ``app_display_name`` - string - A string that will allow the user to identify
    what application owns this pusher.

  - ``app_id`` - string - This is a reverse-DNS style identifier for the application.
    Max length, 64 chars.

  - ``data`` - A dictionary of information for the pusher implementation itself.

    - ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send
      notifications to.

    - ``format`` - string - The format to use when sending notifications to the
      Push Gateway.

  - ``device_display_name`` - string - A string that will allow the user to identify
    what device owns this pusher.

  - ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes.

  - ``lang`` - string - The preferred language for receiving notifications
    (e.g. 'en' or 'en-US')

  - ``profile_tag`` - string - This string determines which set of device specific rules
    this pusher executes.

  - ``pushkey`` - string - This is a unique identifier for this pusher.
    Max length, 512 bytes.

- ``total`` - integer - Number of pushers.

See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
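For readers who want to try the new endpoint, here is a minimal sketch using the third-party `requests` library. The homeserver URL, user ID and admin token below are placeholders, not values from this commit.

```python
import requests

HOMESERVER = "https://matrix.example.com"   # placeholder: your Synapse base URL
USER_ID = "@user:server.com"                # placeholder: fully qualified user ID
ADMIN_TOKEN = "<admin access_token>"        # placeholder: a server admin's access token

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/users/{USER_ID}/pushers",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
body = resp.json()

print("total pushers:", body["total"])
for pusher in body["pushers"]:
    # app_id and pushkey together identify a pusher
    print(pusher["app_id"], pusher["pushkey"])
```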
@@ -60,6 +60,8 @@

 1. Restart Prometheus.

+1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
+
 ## Monitoring workers

 To monitor a Synapse installation using
@@ -3,7 +3,11 @@

 # This is a YAML file containing a standard Python logging configuration
 # dictionary. See [1] for details on the valid settings.
+#
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
 # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md

 version: 1
@@ -1,11 +1,116 @@
 # Structured Logging

-A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".

-Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).

The new version of the document reads:

A structured logging system can be useful when your logs are destined for a
machine to parse and process. By maintaining its machine-readable characteristics,
it enables more efficient searching and aggregations when consumed by software
such as the "ELK stack".

Synapse's structured logging system is configured via the file that Synapse's
`log_config` config option points to. The file should include a formatter which
uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a
handler which uses the above formatter.

There is also a `synapse.logging.JsonFormatter` option which does not include
a timestamp in the resulting JSON. This is useful if the log ingester adds its
own timestamp.

A structured logging configuration looks similar to the following:

```yaml
version: 1

formatters:
    structured:
        class: synapse.logging.TerseJsonFormatter

handlers:
    file:
        class: logging.handlers.TimedRotatingFileHandler
        formatter: structured
        filename: /path/to/my/logs/homeserver.log
        when: midnight
        backupCount: 3  # Does not include the current log file.
        encoding: utf8

loggers:
    synapse:
        level: INFO
        handlers: [file]
    synapse.storage.SQL:
        level: WARNING
```

The above logging config will set Synapse as 'INFO' logging level by default,
with the SQL layer at 'WARNING', and will log to a file, stored as JSON.

It is also possible to configure Synapse to log to a remote endpoint by using the
`synapse.logging.RemoteHandler` class included with Synapse. It takes the
following arguments:

- `host`: Hostname or IP address of the log aggregator.
- `port`: Numerical port to contact on the host.
- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow.

A remote structured logging configuration looks similar to the following:

```yaml
version: 1

formatters:
    structured:
        class: synapse.logging.TerseJsonFormatter

handlers:
    remote:
        class: synapse.logging.RemoteHandler
        formatter: structured
        host: 10.1.2.3
        port: 9999

loggers:
    synapse:
        level: INFO
        handlers: [remote]
    synapse.storage.SQL:
        level: WARNING
```

The above logging config will set Synapse as 'INFO' logging level by default,
with the SQL layer at 'WARNING', and will log JSON formatted messages to a
remote endpoint at 10.1.2.3:9999.

## Upgrading from legacy structured logging configuration

Versions of Synapse prior to v1.23.0 included a custom structured logging
configuration which is deprecated. It used a `structured: true` flag and
configured `drains` instead of `handlers` and `formatters`.

Synapse currently automatically converts the old configuration to the new
configuration, but this will be removed in a future version of Synapse. The
following reference can be used to update your configuration. Based on the drain
`type`, we can pick a new handler:

1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
   with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
   or `ext://sys.stderr` should be used.
2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with
   a location of the file path should be used.
3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler`
   with the host and port should be used.

Then based on the drain `type` we can pick a new formatter:

1. For a type of `console` or `file` no formatter is necessary.
2. For a type of `console_json` or `file_json`: a formatter of
   `synapse.logging.JsonFormatter` should be used.
3. For a type of `console_json_terse` or `network_json_terse`: a formatter of
   `synapse.logging.TerseJsonFormatter` should be used.

For each new handler and formatter they should be added to the logging configuration
and then assigned to either a logger or the root logger.

An example legacy configuration:

```yaml
structured: true
@@ -24,60 +129,33 @@ drains:
     location: homeserver.log
 ```

-The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
-
-## Drain Types
-
-Drain types can be specified by the `type` key.
-
-### `console`
-
-Outputs human-readable logs to the console.
-
-Arguments:
-
-- `location`: Either `stdout` or `stderr`.
-
-### `console_json`
-
-Outputs machine-readable JSON logs to the console.
-
-Arguments:
-
-- `location`: Either `stdout` or `stderr`.
-
-### `console_json_terse`
-
-Outputs machine-readable JSON logs to the console, separated by newlines. This
-format is not designed to be read and re-formatted into human-readable text, but
-is optimal for a logging aggregation system.
-
-Arguments:
-
-- `location`: Either `stdout` or `stderr`.
-
-### `file`
-
-Outputs human-readable logs to a file.
-
-Arguments:
-
-- `location`: An absolute path to the file to log to.
-
-### `file_json`
-
-Outputs machine-readable logs to a file.
-
-Arguments:
-
-- `location`: An absolute path to the file to log to.
-
-### `network_json_terse`
-
-Delivers machine-readable JSON logs to a log aggregator over TCP. This is
-compatible with LogStash's TCP input with the codec set to `json_lines`.
-
-Arguments:
-
-- `host`: Hostname or IP address of the log aggregator.
-- `port`: Numerical port to contact on the host.
+Would be converted into a new configuration:
+
+```yaml
+version: 1
+
+formatters:
+    json:
+        class: synapse.logging.JsonFormatter
+
+handlers:
+    console:
+        class: logging.StreamHandler
+        stream: ext://sys.stdout
+    file:
+        class: logging.FileHandler
+        formatter: json
+        filename: homeserver.log
+
+loggers:
+    synapse:
+        level: INFO
+        handlers: [console, file]
+    synapse.storage.SQL:
+        level: WARNING
+```
+
+The new logging configuration is a bit more verbose, but significantly more
+flexible. It allows for configurations that were not previously possible, such as
+sending plain logs over the network, or using different handlers for different
+modules.
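As a rough illustration of how a config file like the examples above ends up being applied, the sketch below mirrors the `_load_logging_config` helper added elsewhere in this commit; the file path is a placeholder.

```python
import logging
import logging.config

import yaml

LOG_CONFIG_PATH = "/path/to/log_config.yaml"  # placeholder path

with open(LOG_CONFIG_PATH, "rb") as f:
    log_config = yaml.safe_load(f.read())

# Synapse converts a legacy `structured: true` config to the new format before
# this point; a plain dictConfig-style file is passed straight through.
logging.config.dictConfig(log_config)
logging.getLogger("synapse").info("structured logging configured")
```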
mypy.ini (4 lines changed)

@@ -57,6 +57,7 @@ files =
     synapse/server_notices,
     synapse/spam_checker_api,
     synapse/state,
+    synapse/storage/databases/main/appservice.py,
     synapse/storage/databases/main/events.py,
     synapse/storage/databases/main/registration.py,
     synapse/storage/databases/main/stream.py,

@@ -82,6 +83,9 @@ ignore_missing_imports = True
 [mypy-zope]
 ignore_missing_imports = True

+[mypy-bcrypt]
+ignore_missing_imports = True
+
 [mypy-constantly]
 ignore_missing_imports = True
@@ -35,7 +35,7 @@
 showcontent = true

 [tool.black]
-target-version = ['py34']
+target-version = ['py35']
 exclude = '''

 (
|
|||
# then lint everything!
|
||||
if [[ -z ${files+x} ]]; then
|
||||
# Lint all source code files and directories
|
||||
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
|
||||
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -94,3 +94,4 @@ isort "${files[@]}"
|
|||
python3 -m black "${files[@]}"
|
||||
./scripts-dev/config-lint.sh
|
||||
flake8 "${files[@]}"
|
||||
mypy
|
||||
|
|
|
@@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors.

 from typing import Callable, Optional

+from mypy.nodes import ARG_NAMED_OPT
 from mypy.plugin import MethodSigContext, Plugin
 from mypy.typeops import bind_self
-from mypy.types import CallableType
+from mypy.types import CallableType, NoneType


 class SynapsePlugin(Plugin):

@@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:

     It already has *almost* the correct signature, except:

-    1. the `self` argument needs to be marked as "bound"; and
-    2. any `cache_context` argument should be removed.
+    1. the `self` argument needs to be marked as "bound";
+    2. any `cache_context` argument should be removed;
+    3. an optional keyword argument `on_invalidate` should be added.
     """

     # First we mark this as a bound function signature.

@@ -58,16 +60,30 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:

The body of the function now reads as follows (previously the list() copies and
the `pop` calls sat directly under the `if context_arg_index:` guard, and there
was no `on_invalidate` handling):

            context_arg_index = idx
            break

    arg_types = list(signature.arg_types)
    arg_names = list(signature.arg_names)
    arg_kinds = list(signature.arg_kinds)

    if context_arg_index:
        arg_types.pop(context_arg_index)
        arg_names.pop(context_arg_index)
        arg_kinds.pop(context_arg_index)

    # Third, we add an optional "on_invalidate" argument.
    #
    # This is a callable which accepts no input and returns nothing.
    calltyp = CallableType(
        arg_types=[],
        arg_kinds=[],
        arg_names=[],
        ret_type=NoneType(),
        fallback=ctx.api.named_generic_type("builtins.function", []),
    )

    arg_types.append(calltyp)
    arg_names.append("on_invalidate")
    arg_kinds.append(ARG_NAMED_OPT)  # Arg is an optional kwarg.

    signature = signature.copy_modified(
        arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
    )
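In other words, after this change mypy sees every `@cached` method roughly as if it had the signature below. This stub is illustrative only and is not code from the commit; the method name is borrowed from the appservice change elsewhere in this diff.

```python
from typing import Callable, List, Optional


async def get_users_in_room(
    room_id: str,
    *,
    on_invalidate: Optional[Callable[[], None]] = None,
) -> List[str]:
    """Illustrative stub: the plugin appends an optional keyword-only
    `on_invalidate` callback that takes no arguments and returns None."""
    return []
```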
setup.py (1 line added)

@@ -131,6 +131,7 @@ setup(
         "Programming Language :: Python :: 3.6",
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     scripts=["synctl"] + glob.glob("scripts/*"),
     cmdclass={"test": TestCommand},
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -190,10 +191,6 @@ class Auth:
|
|||
|
||||
user_id, app_service = await self._get_appservice_user_id(request)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
opentracing.set_tag("authenticated_entity", user_id)
|
||||
opentracing.set_tag("appservice_id", app_service.id)
|
||||
|
||||
if ip_addr and self._track_appservice_user_ips:
|
||||
await self.store.insert_client_ip(
|
||||
user_id=user_id,
|
||||
|
@ -203,31 +200,38 @@ class Auth:
|
|||
device_id="dummy-device", # stubbed
|
||||
)
|
||||
|
||||
return synapse.types.create_requester(user_id, app_service=app_service)
|
||||
requester = synapse.types.create_requester(
|
||||
user_id, app_service=app_service
|
||||
)
|
||||
|
||||
request.requester = user_id
|
||||
opentracing.set_tag("authenticated_entity", user_id)
|
||||
opentracing.set_tag("user_id", user_id)
|
||||
opentracing.set_tag("appservice_id", app_service.id)
|
||||
|
||||
return requester
|
||||
|
||||
user_info = await self.get_user_by_access_token(
|
||||
access_token, rights, allow_expired=allow_expired
|
||||
)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
shadow_banned = user_info["shadow_banned"]
|
||||
token_id = user_info.token_id
|
||||
is_guest = user_info.is_guest
|
||||
shadow_banned = user_info.shadow_banned
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled and not allow_expired:
|
||||
user_id = user.to_string()
|
||||
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
||||
if await self.store.is_account_expired(
|
||||
user_info.user_id, self.clock.time_msec()
|
||||
):
|
||||
raise AuthError(
|
||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||
)
|
||||
|
||||
# device_id may not be present if get_user_by_access_token has been
|
||||
# stubbed out.
|
||||
device_id = user_info.get("device_id")
|
||||
device_id = user_info.device_id
|
||||
|
||||
if user and access_token and ip_addr:
|
||||
if access_token and ip_addr:
|
||||
await self.store.insert_client_ip(
|
||||
user_id=user.to_string(),
|
||||
user_id=user_info.token_owner,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent,
|
||||
|
@ -241,19 +245,23 @@ class Auth:
|
|||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||
)
|
||||
|
||||
request.authenticated_entity = user.to_string()
|
||||
opentracing.set_tag("authenticated_entity", user.to_string())
|
||||
if device_id:
|
||||
opentracing.set_tag("device_id", device_id)
|
||||
|
||||
return synapse.types.create_requester(
|
||||
user,
|
||||
requester = synapse.types.create_requester(
|
||||
user_info.user_id,
|
||||
token_id,
|
||||
is_guest,
|
||||
shadow_banned,
|
||||
device_id,
|
||||
app_service=app_service,
|
||||
authenticated_entity=user_info.token_owner,
|
||||
)
|
||||
|
||||
request.requester = requester
|
||||
opentracing.set_tag("authenticated_entity", user_info.token_owner)
|
||||
opentracing.set_tag("user_id", user_info.user_id)
|
||||
if device_id:
|
||||
opentracing.set_tag("device_id", device_id)
|
||||
|
||||
return requester
|
||||
except KeyError:
|
||||
raise MissingClientTokenError()
|
||||
|
||||
|
@ -284,7 +292,7 @@ class Auth:
|
|||
|
||||
async def get_user_by_access_token(
|
||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
||||
) -> dict:
|
||||
) -> TokenLookupResult:
|
||||
""" Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
|
@ -293,13 +301,7 @@ class Auth:
|
|||
allow this
|
||||
allow_expired: If False, raises an InvalidClientTokenError
|
||||
if the token is expired
|
||||
Returns:
|
||||
dict that includes:
|
||||
`user` (UserID)
|
||||
`is_guest` (bool)
|
||||
`shadow_banned` (bool)
|
||||
`token_id` (int|None): access token id. May be None if guest
|
||||
`device_id` (str|None): device corresponding to access token
|
||||
|
||||
Raises:
|
||||
InvalidClientTokenError if a user by that token exists, but the token is
|
||||
expired
|
||||
|
@ -309,9 +311,9 @@ class Auth:
|
|||
|
||||
if rights == "access":
|
||||
# first look in the database
|
||||
r = await self._look_up_user_by_access_token(token)
|
||||
r = await self.store.get_user_by_access_token(token)
|
||||
if r:
|
||||
valid_until_ms = r["valid_until_ms"]
|
||||
valid_until_ms = r.valid_until_ms
|
||||
if (
|
||||
not allow_expired
|
||||
and valid_until_ms is not None
|
||||
|
@ -328,7 +330,6 @@ class Auth:
|
|||
# otherwise it needs to be a valid macaroon
|
||||
try:
|
||||
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if rights == "access":
|
||||
if not guest:
|
||||
|
@ -354,23 +355,17 @@ class Auth:
|
|||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
|
||||
ret = TokenLookupResult(
|
||||
user_id=user_id,
|
||||
is_guest=True,
|
||||
# all guests get the same device id
|
||||
"device_id": GUEST_DEVICE_ID,
|
||||
}
|
||||
device_id=GUEST_DEVICE_ID,
|
||||
)
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
|
||||
ret = TokenLookupResult(user_id=user_id, is_guest=False)
|
||||
else:
|
||||
raise RuntimeError("Unknown rights setting %s", rights)
|
||||
return ret
|
||||
|
@ -479,31 +474,15 @@ class Auth:
|
|||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _look_up_user_by_access_token(self, token):
|
||||
ret = await self.store.get_user_by_access_token(token)
|
||||
if not ret:
|
||||
return None
|
||||
|
||||
# we use ret.get() below because *lots* of unit tests stub out
|
||||
# get_user_by_access_token in a way where it only returns a couple of
|
||||
# the fields.
|
||||
user_info = {
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
"shadow_banned": ret.get("shadow_banned"),
|
||||
"device_id": ret.get("device_id"),
|
||||
"valid_until_ms": ret.get("valid_until_ms"),
|
||||
}
|
||||
return user_info
|
||||
|
||||
def get_appservice_by_req(self, request):
|
||||
token = self.get_access_token_from_request(request)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
if not service:
|
||||
logger.warning("Unrecognised appservice access token.")
|
||||
raise InvalidClientTokenError()
|
||||
request.authenticated_entity = service.sender
|
||||
request.requester = synapse.types.create_requester(
|
||||
service.sender, app_service=service
|
||||
)
|
||||
return service
|
||||
|
||||
async def is_server_admin(self, user: UserID) -> bool:
|
||||
|
|
|
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached

 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationServiceApi

@@ -52,11 +52,11 @@ class ApplicationService:
         self,
         token,
         hostname,
+        id,
+        sender,
         url=None,
         namespaces=None,
         hs_token=None,
-        sender=None,
-        id=None,
         protocols=None,
         rate_limited=True,
         ip_range_whitelist=None,

@@ -164,9 +164,9 @@
         does_match = await self.matches_user_in_member_list(event.room_id, store)
         return does_match

-    @cached(num_args=1)
+    @cached(num_args=1, cache_context=True)
     async def matches_user_in_member_list(
-        self, room_id: str, store: "DataStore"
+        self, room_id: str, store: "DataStore", cache_context: _CacheContext,
     ) -> bool:
         """Check if this service is interested a room based upon it's membership

@@ -177,7 +177,9 @@
         Returns:
             True if this service would like to know about this room.
         """
-        member_list = await store.get_users_in_room(room_id)
+        member_list = await store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )

         # check joined member events
         for user_id in member_list:
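A compact sketch of the pattern introduced above (not verbatim from the commit): the outer cached method hands its own invalidation callback to the inner cached call, so its cache entry is dropped whenever the cached member list changes.

```python
from synapse.util.caches.descriptors import cached


class InterestChecker:
    """Illustrative only; assumes `store` exposes a @cached get_users_in_room."""

    def __init__(self, store):
        self.store = store

    @cached(num_args=1, cache_context=True)
    async def has_members(self, room_id: str, cache_context) -> bool:
        # Forward our invalidation callback so this entry is invalidated when
        # the cached member list for the room is invalidated.
        member_list = await self.store.get_users_in_room(
            room_id, on_invalidate=cache_context.invalidate
        )
        return len(member_list) > 0
```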
|
@ -23,7 +23,6 @@ from string import Template
|
|||
import yaml
|
||||
|
||||
from twisted.logger import (
|
||||
ILogObserver,
|
||||
LogBeginner,
|
||||
STDLibLogObserver,
|
||||
eventAsText,
|
||||
|
@ -32,11 +31,9 @@ from twisted.logger import (
|
|||
|
||||
import synapse
|
||||
from synapse.app import _base as appbase
|
||||
from synapse.logging._structured import (
|
||||
reload_structured_logging,
|
||||
setup_structured_logging,
|
||||
)
|
||||
from synapse.logging._structured import setup_structured_logging
|
||||
from synapse.logging.context import LoggingContextFilter
|
||||
from synapse.logging.filter import MetadataFilter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
|
|||
# This is a YAML file containing a standard Python logging configuration
|
||||
# dictionary. See [1] for details on the valid settings.
|
||||
#
|
||||
# Synapse also supports structured logging for machine readable logs which can
|
||||
# be ingested by ELK stacks. See [2] for details.
|
||||
#
|
||||
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
|
||||
# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
|
||||
|
||||
version: 1
|
||||
|
||||
|
@ -176,11 +177,11 @@ class LoggingConfig(Config):
|
|||
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
|
||||
|
||||
|
||||
def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
||||
def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
|
||||
"""
|
||||
Set up Python stdlib logging.
|
||||
Set up Python standard library logging.
|
||||
"""
|
||||
if log_config is None:
|
||||
if log_config_path is None:
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
|
@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
|||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
else:
|
||||
logging.config.dictConfig(log_config)
|
||||
# Load the logging configuration.
|
||||
_load_logging_config(log_config_path)
|
||||
|
||||
# We add a log record factory that runs all messages through the
|
||||
# LoggingContextFilter so that we get the context *at the time we log*
|
||||
|
@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
|||
# filter options, but care must when using e.g. MemoryHandler to buffer
|
||||
# writes.
|
||||
|
||||
log_filter = LoggingContextFilter(request="")
|
||||
log_context_filter = LoggingContextFilter(request="")
|
||||
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
|
||||
old_factory = logging.getLogRecordFactory()
|
||||
|
||||
def factory(*args, **kwargs):
|
||||
record = old_factory(*args, **kwargs)
|
||||
log_filter.filter(record)
|
||||
log_context_filter.filter(record)
|
||||
log_metadata_filter.filter(record)
|
||||
return record
|
||||
|
||||
logging.setLogRecordFactory(factory)
|
||||
|
@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
|||
if not config.no_redirect_stdio:
|
||||
print("Redirected stdout/stderr to logs")
|
||||
|
||||
return observer
|
||||
|
||||
|
||||
def _reload_stdlib_logging(*args, log_config=None):
|
||||
logger = logging.getLogger("")
|
||||
def _load_logging_config(log_config_path: str) -> None:
|
||||
"""
|
||||
Configure logging from a log config path.
|
||||
"""
|
||||
with open(log_config_path, "rb") as f:
|
||||
log_config = yaml.safe_load(f.read())
|
||||
|
||||
if not log_config:
|
||||
logger.warning("Reloaded a blank config?")
|
||||
logging.warning("Loaded a blank logging config?")
|
||||
|
||||
# If the old structured logging configuration is being used, convert it to
|
||||
# the new style configuration.
|
||||
if "structured" in log_config and log_config.get("structured"):
|
||||
log_config = setup_structured_logging(log_config)
|
||||
|
||||
logging.config.dictConfig(log_config)
|
||||
|
||||
|
||||
def _reload_logging_config(log_config_path):
|
||||
"""
|
||||
Reload the log configuration from the file and apply it.
|
||||
"""
|
||||
# If no log config path was given, it cannot be reloaded.
|
||||
if log_config_path is None:
|
||||
return
|
||||
|
||||
_load_logging_config(log_config_path)
|
||||
logging.info("Reloaded log config from %s due to SIGHUP", log_config_path)
|
||||
|
||||
|
||||
def setup_logging(
|
||||
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
|
||||
) -> ILogObserver:
|
||||
) -> None:
|
||||
"""
|
||||
Set up the logging subsystem.
|
||||
|
||||
|
@ -282,41 +305,18 @@ def setup_logging(
|
|||
|
||||
logBeginner: The Twisted logBeginner to use.
|
||||
|
||||
Returns:
|
||||
The "root" Twisted Logger observer, suitable for sending logs to from a
|
||||
Logger instance.
|
||||
"""
|
||||
log_config = config.worker_log_config if use_worker_options else config.log_config
|
||||
|
||||
def read_config(*args, callback=None):
|
||||
if log_config is None:
|
||||
return None
|
||||
|
||||
with open(log_config, "rb") as f:
|
||||
log_config_body = yaml.safe_load(f.read())
|
||||
|
||||
if callback:
|
||||
callback(log_config=log_config_body)
|
||||
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
|
||||
|
||||
return log_config_body
|
||||
|
||||
log_config_body = read_config()
|
||||
|
||||
if log_config_body and log_config_body.get("structured") is True:
|
||||
logger = setup_structured_logging(
|
||||
hs, config, log_config_body, logBeginner=logBeginner
|
||||
log_config_path = (
|
||||
config.worker_log_config if use_worker_options else config.log_config
|
||||
)
|
||||
appbase.register_sighup(read_config, callback=reload_structured_logging)
|
||||
else:
|
||||
logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
|
||||
appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
|
||||
|
||||
# make sure that the first thing we log is a thing we can grep backwards
|
||||
# for
|
||||
# Perform one-time logging configuration.
|
||||
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
|
||||
# Add a SIGHUP handler to reload the logging configuration, if one is available.
|
||||
appbase.register_sighup(_reload_logging_config, log_config_path)
|
||||
|
||||
# Log immediately so we can grep backwards.
|
||||
logging.warning("***** STARTING SERVER *****")
|
||||
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
||||
logging.info("Server hostname: %s", config.server_name)
|
||||
logging.info("Instance name: %s", hs.get_instance_name())
|
||||
|
||||
return logger
|
||||
|
|
|
@@ -368,7 +368,7 @@ class FrozenEvent(EventBase):
         return self.__repr__()

     def __repr__(self):
-        return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
+        return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
             self.get("event_id", None),
             self.get("type", None),
             self.get("state_key", None),

@@ -451,7 +451,7 @@ class FrozenEventV2(EventBase):
         return self.__repr__()

     def __repr__(self):
-        return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+        return "<%s event_id=%r, type=%r, state_key=%r>" % (
             self.__class__.__name__,
             self.event_id,
             self.get("type", None),
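The switch to `%r` is what the "clarify representation of events in logfiles" changelog entry refers to: `repr()` quoting distinguishes a missing value from the literal string "None". A small illustrative comparison, not code from the commit:

```python
event_id, etype, state_key = "$abc:example.com", "m.room.message", None

print("<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (event_id, etype, state_key))
# -> <FrozenEvent event_id='$abc:example.com', type='m.room.message', state_key='None'>

print("<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (event_id, etype, state_key))
# -> <FrozenEvent event_id='$abc:example.com', type='m.room.message', state_key=None>
```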
|
@@ -154,7 +154,7 @@ class Authenticator:
         )

         logger.debug("Request from %s", origin)
-        request.authenticated_entity = origin
+        request.requester = origin

         # If we get a valid signed request from the other side, its probably
         # alive
|
|
@ -12,9 +12,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
|
|||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
|
||||
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
||||
|
||||
|
||||
class ApplicationServicesHandler:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.appservice_api = hs.get_application_service_api()
|
||||
|
@ -247,7 +250,9 @@ class ApplicationServicesHandler:
|
|||
service, "presence", new_token
|
||||
)
|
||||
|
||||
async def _handle_typing(self, service: ApplicationService, new_token: int):
|
||||
async def _handle_typing(
|
||||
self, service: ApplicationService, new_token: int
|
||||
) -> List[JsonDict]:
|
||||
typing_source = self.event_sources.sources["typing"]
|
||||
# Get the typing events from just before current
|
||||
typing, _ = await typing_source.get_new_events_as(
|
||||
|
@ -259,7 +264,7 @@ class ApplicationServicesHandler:
|
|||
)
|
||||
return typing
|
||||
|
||||
async def _handle_receipts(self, service: ApplicationService):
|
||||
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
service, "read_receipt"
|
||||
)
|
||||
|
@ -271,7 +276,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
async def _handle_presence(
|
||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||
):
|
||||
) -> List[JsonDict]:
|
||||
events = [] # type: List[JsonDict]
|
||||
presence_source = self.event_sources.sources["presence"]
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
|
@ -301,11 +306,11 @@ class ApplicationServicesHandler:
|
|||
|
||||
return events
|
||||
|
||||
async def query_user_exists(self, user_id):
|
||||
async def query_user_exists(self, user_id: str) -> bool:
|
||||
"""Check if any application service knows this user_id exists.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to query if they exist on any AS.
|
||||
user_id: The user to query if they exist on any AS.
|
||||
Returns:
|
||||
True if this user exists on at least one application service.
|
||||
"""
|
||||
|
@ -316,11 +321,13 @@ class ApplicationServicesHandler:
|
|||
return True
|
||||
return False
|
||||
|
||||
async def query_room_alias_exists(self, room_alias):
|
||||
async def query_room_alias_exists(
|
||||
self, room_alias: RoomAlias
|
||||
) -> Optional[RoomAliasMapping]:
|
||||
"""Check if an application service knows this room alias exists.
|
||||
|
||||
Args:
|
||||
room_alias(RoomAlias): The room alias to query.
|
||||
room_alias: The room alias to query.
|
||||
Returns:
|
||||
namedtuple: with keys "room_id" and "servers" or None if no
|
||||
association can be found.
|
||||
|
@ -336,10 +343,13 @@ class ApplicationServicesHandler:
|
|||
)
|
||||
if is_known_alias:
|
||||
# the alias exists now so don't query more ASes.
|
||||
result = await self.store.get_association_from_room_alias(room_alias)
|
||||
return result
|
||||
return await self.store.get_association_from_room_alias(room_alias)
|
||||
|
||||
async def query_3pe(self, kind, protocol, fields):
|
||||
return None
|
||||
|
||||
async def query_3pe(
|
||||
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
|
||||
) -> List[JsonDict]:
|
||||
services = self._get_services_for_3pn(protocol)
|
||||
|
||||
results = await make_deferred_yieldable(
|
||||
|
@ -361,7 +371,9 @@ class ApplicationServicesHandler:
|
|||
|
||||
return ret
|
||||
|
||||
async def get_3pe_protocols(self, only_protocol=None):
|
||||
async def get_3pe_protocols(
|
||||
self, only_protocol: Optional[str] = None
|
||||
) -> Dict[str, JsonDict]:
|
||||
services = self.store.get_app_services()
|
||||
protocols = {} # type: Dict[str, List[JsonDict]]
|
||||
|
||||
|
@ -379,7 +391,7 @@ class ApplicationServicesHandler:
|
|||
if info is not None:
|
||||
protocols[p].append(info)
|
||||
|
||||
def _merge_instances(infos):
|
||||
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
|
||||
if not infos:
|
||||
return {}
|
||||
|
||||
|
@ -394,19 +406,17 @@ class ApplicationServicesHandler:
|
|||
|
||||
return combined
|
||||
|
||||
for p in protocols.keys():
|
||||
protocols[p] = _merge_instances(protocols[p])
|
||||
return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
|
||||
|
||||
return protocols
|
||||
|
||||
async def _get_services_for_event(self, event):
|
||||
async def _get_services_for_event(
|
||||
self, event: EventBase
|
||||
) -> List[ApplicationService]:
|
||||
"""Retrieve a list of application services interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check. Can be None if alias_list is not.
|
||||
event: The event to check. Can be None if alias_list is not.
|
||||
Returns:
|
||||
list<ApplicationService>: A list of services interested in this
|
||||
event based on the service regex.
|
||||
A list of services interested in this event based on the service regex.
|
||||
"""
|
||||
services = self.store.get_app_services()
|
||||
|
||||
|
@ -420,17 +430,15 @@ class ApplicationServicesHandler:
|
|||
|
||||
return interested_list
|
||||
|
||||
def _get_services_for_user(self, user_id):
|
||||
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
|
||||
services = self.store.get_app_services()
|
||||
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
|
||||
return interested_list
|
||||
return [s for s in services if (s.is_interested_in_user(user_id))]
|
||||
|
||||
def _get_services_for_3pn(self, protocol):
|
||||
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
|
||||
services = self.store.get_app_services()
|
||||
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
|
||||
return interested_list
|
||||
return [s for s in services if s.is_interested_in_protocol(protocol)]
|
||||
|
||||
async def _is_unknown_user(self, user_id):
|
||||
async def _is_unknown_user(self, user_id: str) -> bool:
|
||||
if not self.is_mine_id(user_id):
|
||||
# we don't know if they are unknown or not since it isn't one of our
|
||||
# users. We can't poke ASes.
|
||||
|
@ -445,9 +453,8 @@ class ApplicationServicesHandler:
|
|||
service_list = [s for s in services if s.sender == user_id]
|
||||
return len(service_list) == 0
|
||||
|
||||
async def _check_user_exists(self, user_id):
|
||||
async def _check_user_exists(self, user_id: str) -> bool:
|
||||
unknown_user = await self._is_unknown_user(user_id)
|
||||
if unknown_user:
|
||||
exists = await self.query_user_exists(user_id)
|
||||
return exists
|
||||
return await self.query_user_exists(user_id)
|
||||
return True
|
||||
|
|
|
@ -18,10 +18,20 @@ import logging
|
|||
import time
|
||||
import unicodedata
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
import bcrypt # type: ignore[import]
|
||||
import bcrypt
|
||||
import pymacaroons
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
|
@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
|
|||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
|
|||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||
|
@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
|
|||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = provider.on_logged_out(
|
||||
user_id=str(user_info["user"]),
|
||||
device_id=user_info["device_id"],
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info["token_id"] is not None:
|
||||
if user_info.token_id is not None:
|
||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
str(user_info["user"]), (user_info["token_id"],)
|
||||
user_info.user_id, (user_info.token_id,)
|
||||
)
|
||||
|
||||
async def delete_access_tokens_for_user(
|
||||
|
|
|
@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
|||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
|
@ -928,7 +927,7 @@ class EventCreationHandler:
|
|||
|
||||
# Ensure that we can round trip before trying to persist in db
|
||||
try:
|
||||
dump = frozendict_json_encoder.encode(event.content)
|
||||
dump = json_encoder.encode(event.content)
|
||||
json_decoder.decode(dump)
|
||||
except Exception:
|
||||
logger.exception("Failed to encode content: %r", event.content)
|
||||
|
@ -1100,34 +1099,13 @@ class EventCreationHandler:
|
|||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.INVITE:
|
||||
|
||||
def is_inviter_member_event(e):
|
||||
return e.type == EventTypes.Member and e.sender == event.sender
|
||||
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
|
||||
# We know this event is not an outlier, so this must be
|
||||
# non-None.
|
||||
assert current_state_ids is not None
|
||||
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
for k, e_id in current_state_ids.items()
|
||||
if k[0] in self.room_invite_state_types
|
||||
or k == (EventTypes.Member, event.sender)
|
||||
]
|
||||
|
||||
state_to_include = await self.store.get_events(state_to_include_ids)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
"type": e.type,
|
||||
"state_key": e.state_key,
|
||||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for e in state_to_include.values()
|
||||
]
|
||||
event.unsigned[
|
||||
"invite_room_state"
|
||||
] = await self.store.get_stripped_room_state_from_event_context(
|
||||
context,
|
||||
self.room_invite_state_types,
|
||||
membership_user_id=event.sender,
|
||||
)
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
if not self.hs.is_mine(invitee):
|
||||
|
|
|
@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
|
|||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
import synapse.server
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
|
|||
class BasePresenceHandler(abc.ABC):
|
||||
"""Parts of the PresenceHandler that are shared between workers and master"""
|
||||
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
|
||||
|
||||
class PresenceHandler(BasePresenceHandler):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
|
|||
|
||||
|
||||
class PresenceEventSource:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
# We can't call get_presence_handler here because there's a cycle:
|
||||
#
|
||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||
|
@ -1071,12 +1071,14 @@ class PresenceEventSource:
|
|||
|
||||
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
||||
|
||||
user_ids_changed = set()
|
||||
user_ids_changed = set() # type: Collection[str]
|
||||
changed = None
|
||||
if from_key:
|
||||
changed = stream_change_cache.get_all_entities_changed(from_key)
|
||||
|
||||
if changed is not None and len(changed) < 500:
|
||||
assert isinstance(user_ids_changed, set)
|
||||
|
||||
# For small deltas, its quicker to get all changes and then
|
||||
# work out if we share a room or they're in our presence list
|
||||
get_updates_counter.labels("stream").inc()
|
||||
|
|
|
@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
             user_data = await self.auth.get_user_by_access_token(guest_access_token)
-            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+            if (
+                not user_data.is_guest
+                or UserID.from_string(user_data.user_id).localpart != localpart
+            ):
                 raise AuthError(
                     403,
                     "Cannot register taken user ID without valid guest "

@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
-            token_id = user_tuple["token_id"]
+            token_id = user_tuple.token_id

             await self.pusher_pool.add_pusher(
                 user_id=user_id,
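For orientation, the `user_data.is_guest` and `user_tuple.token_id` attribute access above works because `get_user_by_access_token` now returns an attrs-style result object rather than a dict. The sketch below is an approximation based on the attributes used in this diff, not the actual class definition.

```python
from typing import Optional

import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResultSketch:
    """Approximation of the fields accessed elsewhere in this commit."""

    user_id: str
    is_guest: bool = False
    shadow_banned: bool = False
    token_id: Optional[int] = None
    device_id: Optional[str] = None
    valid_until_ms: Optional[int] = None


result = TokenLookupResultSketch(user_id="@alice:example.com", token_id=42)
assert result.token_id == 42 and not result.is_guest
```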
|
@ -771,15 +771,22 @@ class RoomCreationHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
|
||||
for invitee in invite_list:
|
||||
# we avoid dropping the lock between invites, as otherwise joins can
|
||||
# start coming in and making the createRoom slow.
|
||||
#
|
||||
# we also don't need to check the requester's shadow-ban here, as we
|
||||
# have already done so above (and potentially emptied invite_list).
|
||||
with (await self.room_member_handler.member_linearizer.queue((room_id,))):
|
||||
content = {}
|
||||
is_direct = config.get("is_direct", None)
|
||||
if is_direct:
|
||||
content["is_direct"] = is_direct
|
||||
|
||||
# Note that update_membership with an action of "invite" can raise a
|
||||
# ShadowBanError, but this was handled above by emptying invite_list.
|
||||
_, last_stream_id = await self.room_member_handler.update_membership(
|
||||
for invitee in invite_list:
|
||||
(
|
||||
_,
|
||||
last_stream_id,
|
||||
) = await self.room_member_handler.update_membership_locked(
|
||||
requester,
|
||||
UserID.from_string(invitee),
|
||||
room_id,
|
||||
|
|
|
@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
# haproxy would have timed the request out anyway...
|
||||
raise SynapseError(504, "took to long to process")
|
||||
|
||||
result = await self._update_membership(
|
||||
result = await self.update_membership_locked(
|
||||
requester,
|
||||
target,
|
||||
room_id,
|
||||
|
@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
|
||||
return result
|
||||
|
||||
async def _update_membership(
|
||||
async def update_membership_locked(
|
||||
self,
|
||||
requester: Requester,
|
||||
target: UserID,
|
||||
|
@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
content: Optional[dict] = None,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[str, int]:
|
||||
"""Helper for update_membership.
|
||||
|
||||
Assumes that the membership linearizer is already held for the room.
|
||||
"""
|
||||
content_specified = bool(content)
|
||||
if content is None:
|
||||
content = {}
|
||||
|
|
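Taken together, the two hunks above split `update_membership` into a lock-acquiring wrapper and an `update_membership_locked` body, which lets `createRoom` hold the member linearizer across all of its invites. A hedged sketch of the resulting call pattern (argument names and order are simplified from the real signatures; the `"invite"` action argument is assumed to match `update_membership`):

```python
# Sketch only: take the per-room membership lock once, then send every invite
# through the "_locked" variant so joins cannot interleave mid-createRoom.
async def invite_all(member_handler, requester, room_id, invitees, content):
    with (await member_handler.member_linearizer.queue((room_id,))):
        last_stream_id = None
        for invitee in invitees:
            _, last_stream_id = await member_handler.update_membership_locked(
                requester,
                invitee,      # a UserID
                room_id,
                "invite",     # assumed action argument
                ratelimit=False,
                content=content,
            )
        return last_stream_id
```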
|
@ -359,7 +359,7 @@ class SimpleHttpClient:
|
|||
agent=self.agent,
|
||||
data=body_producer,
|
||||
headers=headers,
|
||||
**self._extra_treq_args
|
||||
**self._extra_treq_args,
|
||||
) # type: defer.Deferred
|
||||
|
||||
# we use our own timeout mechanism rather than treq's as a workaround
|
||||
|
|
|
@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
|
|||
from twisted.web.static import File, NoRangeStaticProducer
|
||||
from twisted.web.util import redirectTo
|
||||
|
||||
import synapse.events
|
||||
import synapse.metrics
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
|
@ -620,7 +618,7 @@ def respond_with_json(
|
|||
if pretty_print:
|
||||
encoder = iterencode_pretty_printed_json
|
||||
else:
|
||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||
if canonical_json:
|
||||
encoder = iterencode_canonical_json
|
||||
else:
|
||||
encoder = _encode_json_bytes
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.server import Request, Site
|
||||
|
@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
|
|||
from synapse.http import redact_uri
|
||||
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
from synapse.types import Requester
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,9 +55,12 @@ class SynapseRequest(Request):
|
|||
Request.__init__(self, channel, *args, **kw)
|
||||
self.site = channel.site
|
||||
self._channel = channel # this is used by the tests
|
||||
self.authenticated_entity = None
|
||||
self.start_time = 0.0
|
||||
|
||||
# The requester, if authenticated. For federation requests this is the
|
||||
# server name, for client requests this is the Requester object.
|
||||
self.requester = None # type: Optional[Union[Requester, str]]
|
||||
|
||||
# we can't yet create the logcontext, as we don't know the method.
|
||||
self.logcontext = None # type: Optional[LoggingContext]
|
||||
|
||||
|
@ -271,11 +275,23 @@ class SynapseRequest(Request):
|
|||
# to the client (nb may be negative)
|
||||
response_send_time = self.finish_time - self._processing_finished_time
|
||||
|
||||
# need to decode as it could be raw utf-8 bytes
|
||||
# from an IDN servname in an auth header
|
||||
authenticated_entity = self.authenticated_entity
|
||||
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
||||
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
||||
# Convert the requester into a string that we can log
|
||||
authenticated_entity = None
|
||||
if isinstance(self.requester, str):
|
||||
authenticated_entity = self.requester
|
||||
elif isinstance(self.requester, Requester):
|
||||
authenticated_entity = self.requester.authenticated_entity
|
||||
|
||||
# If this is a request where the target user doesn't match the user who
|
||||
# authenticated (e.g. an admin is puppeting a user) then we log both.
|
||||
if self.requester.user.to_string() != authenticated_entity:
|
||||
authenticated_entity = "{},{}".format(
|
||||
authenticated_entity, self.requester.user.to_string(),
|
||||
)
|
||||
elif self.requester is not None:
|
||||
# This shouldn't happen, but we log it so we don't lose information
|
||||
# and can see that we're doing something wrong.
|
||||
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
|
||||
|
||||
# ...or could be raw utf-8 bytes in the User-Agent header.
|
||||
# N.B. if you don't do this, the logger explodes cryptically
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# These are imported to allow for nicer logging configuration files.
|
||||
from synapse.logging._remote import RemoteHandler
|
||||
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
|
||||
|
||||
__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]
|
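These re-exports let standard `dictConfig`-style logging configuration refer to the new handler and formatters by dotted path. A minimal, hedged sketch of such a config (the host, port and logger level are placeholder values; the handler arguments mirror those of `RemoteHandler` shown below):

```python
import logging.config

LOG_CONFIG = {
    "version": 1,
    "formatters": {
        "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
    },
    "handlers": {
        "remote": {
            "class": "synapse.logging.RemoteHandler",
            "formatter": "tersejson",
            "host": "127.0.0.1",      # placeholder log aggregator address
            "port": 9000,
            "maximum_buffer": 1000,
        },
    },
    "loggers": {"synapse": {"level": "INFO"}},
    "root": {"handlers": ["remote"]},
}

logging.config.dictConfig(LOG_CONFIG)
```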
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from collections import deque
|
||||
|
@ -21,10 +22,11 @@ from math import floor
|
|||
from typing import Callable, Optional
|
||||
|
||||
import attr
|
||||
from typing_extensions import Deque
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.application.internet import ClientService
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
from twisted.internet.endpoints import (
|
||||
HostnameEndpoint,
|
||||
TCP4ClientEndpoint,
|
||||
|
@ -32,7 +34,9 @@ from twisted.internet.endpoints import (
|
|||
)
|
||||
from twisted.internet.interfaces import IPushProducer, ITransport
|
||||
from twisted.internet.protocol import Factory, Protocol
|
||||
from twisted.logger import ILogObserver, Logger, LogLevel
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
|
@ -45,11 +49,11 @@ class LogProducer:
|
|||
Args:
|
||||
buffer: Log buffer to read logs from.
|
||||
transport: Transport to write to.
|
||||
format_event: A callable to format the log entry to a string.
|
||||
format: A callable to format the log record to a string.
|
||||
"""
|
||||
|
||||
transport = attr.ib(type=ITransport)
|
||||
format_event = attr.ib(type=Callable[[dict], str])
|
||||
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||
_buffer = attr.ib(type=deque)
|
||||
_paused = attr.ib(default=False, type=bool, init=False)
|
||||
|
||||
|
@ -61,16 +65,19 @@ class LogProducer:
|
|||
self._buffer = deque()
|
||||
|
||||
def resumeProducing(self):
|
||||
# If we're already producing, nothing to do.
|
||||
self._paused = False
|
||||
|
||||
# Loop until paused.
|
||||
while self._paused is False and (self._buffer and self.transport.connected):
|
||||
try:
|
||||
# Request the next event and format it.
|
||||
event = self._buffer.popleft()
|
||||
msg = self.format_event(event)
|
||||
# Request the next record and format it.
|
||||
record = self._buffer.popleft()
|
||||
msg = self._format(record)
|
||||
|
||||
# Send it as a new line over the transport.
|
||||
self.transport.write(msg.encode("utf8"))
|
||||
self.transport.write(b"\n")
|
||||
except Exception:
|
||||
# Something has gone wrong writing to the transport -- log it
|
||||
# and break out of the while.
|
||||
|
@ -78,76 +85,85 @@ class LogProducer:
|
|||
break
|
||||
|
||||
|
||||
@attr.s
|
||||
@implementer(ILogObserver)
|
||||
class TCPLogObserver:
|
||||
class RemoteHandler(logging.Handler):
|
||||
"""
|
||||
An IObserver that writes JSON logs to a TCP target.
|
||||
A logging handler that writes logs to a TCP target.
|
||||
|
||||
Args:
|
||||
hs (HomeServer): The homeserver that is being logged for.
|
||||
host: The host of the logging target.
|
||||
port: The logging target's port.
|
||||
format_event: A callable to format the log entry to a string.
|
||||
maximum_buffer: The maximum buffer size.
|
||||
"""
|
||||
|
||||
hs = attr.ib()
|
||||
host = attr.ib(type=str)
|
||||
port = attr.ib(type=int)
|
||||
format_event = attr.ib(type=Callable[[dict], str])
|
||||
maximum_buffer = attr.ib(type=int)
|
||||
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
|
||||
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
|
||||
_logger = attr.ib(default=attr.Factory(Logger))
|
||||
_producer = attr.ib(default=None, type=Optional[LogProducer])
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
maximum_buffer: int = 1000,
|
||||
level=logging.NOTSET,
|
||||
_reactor=None,
|
||||
):
|
||||
super().__init__(level=level)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.maximum_buffer = maximum_buffer
|
||||
|
||||
def start(self) -> None:
|
||||
self._buffer = deque() # type: Deque[logging.LogRecord]
|
||||
self._connection_waiter = None # type: Optional[Deferred]
|
||||
self._producer = None # type: Optional[LogProducer]
|
||||
|
||||
# Connect without DNS lookups if it's a direct IP.
|
||||
if _reactor is None:
|
||||
from twisted.internet import reactor
|
||||
|
||||
_reactor = reactor
|
||||
|
||||
try:
|
||||
ip = ip_address(self.host)
|
||||
if isinstance(ip, IPv4Address):
|
||||
endpoint = TCP4ClientEndpoint(
|
||||
self.hs.get_reactor(), self.host, self.port
|
||||
)
|
||||
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
|
||||
elif isinstance(ip, IPv6Address):
|
||||
endpoint = TCP6ClientEndpoint(
|
||||
self.hs.get_reactor(), self.host, self.port
|
||||
)
|
||||
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
|
||||
else:
|
||||
raise ValueError("Unknown IP address provided: %s" % (self.host,))
|
||||
except ValueError:
|
||||
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
|
||||
endpoint = HostnameEndpoint(_reactor, self.host, self.port)
|
||||
|
||||
factory = Factory.forProtocol(Protocol)
|
||||
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
|
||||
self._service = ClientService(endpoint, factory, clock=_reactor)
|
||||
self._service.startService()
|
||||
self._stopping = False
|
||||
self._connect()
|
||||
|
||||
def stop(self):
|
||||
def close(self):
|
||||
self._stopping = True
|
||||
self._service.stopService()
|
||||
|
||||
def _connect(self) -> None:
|
||||
"""
|
||||
Triggers an attempt to connect then write to the remote if not already writing.
|
||||
"""
|
||||
# Do not attempt to open multiple connections.
|
||||
if self._connection_waiter:
|
||||
return
|
||||
|
||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
||||
|
||||
@self._connection_waiter.addErrback
|
||||
def fail(r):
|
||||
r.printTraceback(file=sys.__stderr__)
|
||||
def fail(failure: Failure) -> None:
|
||||
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||
# reconnect (this will cause an infinite loop of errors).
|
||||
if failure.check(CancelledError) and self._stopping:
|
||||
return
|
||||
|
||||
# For a different error, print the traceback and re-connect.
|
||||
failure.printTraceback(file=sys.__stderr__)
|
||||
self._connection_waiter = None
|
||||
self._connect()
|
||||
|
||||
@self._connection_waiter.addCallback
|
||||
def writer(r):
|
||||
def writer(result: Protocol) -> None:
|
||||
# We have a connection. If we already have a producer, and its
|
||||
# transport is the same, just trigger a resumeProducing.
|
||||
if self._producer and r.transport is self._producer.transport:
|
||||
if self._producer and result.transport is self._producer.transport:
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
return
|
||||
|
@ -158,29 +174,29 @@ class TCPLogObserver:
|
|||
|
||||
# Make a new producer and start it.
|
||||
self._producer = LogProducer(
|
||||
buffer=self._buffer,
|
||||
transport=r.transport,
|
||||
format_event=self.format_event,
|
||||
buffer=self._buffer, transport=result.transport, format=self.format,
|
||||
)
|
||||
r.transport.registerProducer(self._producer, True)
|
||||
result.transport.registerProducer(self._producer, True)
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
|
||||
self._connection_waiter.addCallbacks(writer, fail)
|
||||
|
||||
def _handle_pressure(self) -> None:
|
||||
"""
|
||||
Handle backpressure by shedding events.
|
||||
Handle backpressure by shedding records.
|
||||
|
||||
The buffer will be shed, in this order, until it is below the maximum:
|
||||
- Shed DEBUG events
|
||||
- Shed INFO events
|
||||
- Shed the middle 50% of the events.
|
||||
- Shed DEBUG records.
|
||||
- Shed INFO records.
|
||||
- Shed the middle 50% of the records.
|
||||
"""
|
||||
if len(self._buffer) <= self.maximum_buffer:
|
||||
return
|
||||
|
||||
# Strip out DEBUGs
|
||||
self._buffer = deque(
|
||||
filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
|
||||
filter(lambda record: record.levelno > logging.DEBUG, self._buffer)
|
||||
)
|
||||
|
||||
if len(self._buffer) <= self.maximum_buffer:
|
||||
|
@ -188,7 +204,7 @@ class TCPLogObserver:
|
|||
|
||||
# Strip out INFOs
|
||||
self._buffer = deque(
|
||||
filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
|
||||
filter(lambda record: record.levelno > logging.INFO, self._buffer)
|
||||
)
|
||||
|
||||
if len(self._buffer) <= self.maximum_buffer:
|
||||
|
@ -209,8 +225,8 @@ class TCPLogObserver:
|
|||
|
||||
self._buffer.extend(reversed(end_buffer))
|
||||
|
||||
def __call__(self, event: dict) -> None:
|
||||
self._buffer.append(event)
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
self._buffer.append(record)
|
||||
|
||||
# Handle backpressure, if it exists.
|
||||
try:
|
||||
|
@ -219,7 +235,7 @@ class TCPLogObserver:
|
|||
# If handling backpressure fails, clear the buffer and log the
|
||||
# exception.
|
||||
self._buffer.clear()
|
||||
self._logger.failure("Failed clearing backpressure")
|
||||
logger.warning("Failed clearing backpressure")
|
||||
|
||||
# Try and write immediately.
|
||||
self._connect()
|
||||
|
|
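The `_handle_pressure` logic above drops records in a fixed priority order when the buffer overflows. A self-contained sketch of that shedding policy on a plain `deque` (a simplification of the handler's own method):

```python
import logging
from collections import deque
from typing import Deque

def shed_records(buffer: Deque[logging.LogRecord], maximum: int) -> Deque[logging.LogRecord]:
    """Drop DEBUG records, then INFO records, then the middle 50%."""
    if len(buffer) <= maximum:
        return buffer

    # Strip out DEBUGs.
    buffer = deque(r for r in buffer if r.levelno > logging.DEBUG)
    if len(buffer) <= maximum:
        return buffer

    # Strip out INFOs.
    buffer = deque(r for r in buffer if r.levelno > logging.INFO)
    if len(buffer) <= maximum:
        return buffer

    # Keep the oldest and newest quarters; drop the middle half.
    quarter = len(buffer) // 4
    kept_start = [buffer.popleft() for _ in range(quarter)]
    kept_end = [buffer.pop() for _ in range(quarter)]
    new_buffer = deque(kept_start)
    new_buffer.extend(reversed(kept_end))
    return new_buffer
```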
|
@ -12,138 +12,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os.path
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from typing import List
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from constantly import NamedConstant, Names, ValueConstant, Values
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.logger import (
|
||||
FileLogObserver,
|
||||
FilteringLogObserver,
|
||||
ILogObserver,
|
||||
LogBeginner,
|
||||
Logger,
|
||||
LogLevel,
|
||||
LogLevelFilterPredicate,
|
||||
LogPublisher,
|
||||
eventAsText,
|
||||
jsonFileLogObserver,
|
||||
)
|
||||
from constantly import NamedConstant, Names
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.logging._terse_json import (
|
||||
TerseJSONToConsoleLogObserver,
|
||||
TerseJSONToTCPLogObserver,
|
||||
)
|
||||
from synapse.logging.context import current_context
|
||||
|
||||
|
||||
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
|
||||
"""
|
||||
Convert a stdlib log level to Twisted's log level.
|
||||
"""
|
||||
lvl = level.lower().replace("warning", "warn")
|
||||
return LogLevel.levelWithName(lvl)
|
||||
|
||||
|
||||
@attr.s
|
||||
@implementer(ILogObserver)
|
||||
class LogContextObserver:
|
||||
"""
|
||||
An ILogObserver which adds Synapse-specific log context information.
|
||||
|
||||
Attributes:
|
||||
observer (ILogObserver): The target parent observer.
|
||||
"""
|
||||
|
||||
observer = attr.ib()
|
||||
|
||||
def __call__(self, event: dict) -> None:
|
||||
"""
|
||||
Consume a log event and emit it to the parent observer after filtering
|
||||
and adding log context information.
|
||||
|
||||
Args:
|
||||
event (dict)
|
||||
"""
|
||||
# Filter out some useless events that Twisted outputs
|
||||
if "log_text" in event:
|
||||
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
|
||||
return
|
||||
|
||||
if event["log_text"].startswith("(UDP Port "):
|
||||
return
|
||||
|
||||
if event["log_text"].startswith("Timing out client") or event[
|
||||
"log_format"
|
||||
].startswith("Timing out client"):
|
||||
return
|
||||
|
||||
context = current_context()
|
||||
|
||||
# Copy the context information to the log event.
|
||||
context.copy_to_twisted_log_entry(event)
|
||||
|
||||
self.observer(event)
|
||||
|
||||
|
||||
class PythonStdlibToTwistedLogger(logging.Handler):
|
||||
"""
|
||||
Transform a Python stdlib log message into a Twisted one.
|
||||
"""
|
||||
|
||||
def __init__(self, observer, *args, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
observer (ILogObserver): A Twisted logging observer.
|
||||
*args, **kwargs: Args/kwargs to be passed to logging.Handler.
|
||||
"""
|
||||
self.observer = observer
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""
|
||||
Emit a record to Twisted's observer.
|
||||
|
||||
Args:
|
||||
record (logging.LogRecord)
|
||||
"""
|
||||
|
||||
self.observer(
|
||||
{
|
||||
"log_time": record.created,
|
||||
"log_text": record.getMessage(),
|
||||
"log_format": "{log_text}",
|
||||
"log_namespace": record.name,
|
||||
"log_level": stdlib_log_level_to_twisted(record.levelname),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
|
||||
"""
|
||||
A log observer that formats events like the traditional log formatter and
|
||||
sends them to `outFile`.
|
||||
|
||||
Args:
|
||||
outFile (file object): The file object to write to.
|
||||
"""
|
||||
|
||||
def formatEvent(_event: dict) -> str:
|
||||
event = dict(_event)
|
||||
event["log_level"] = event["log_level"].name.upper()
|
||||
event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
|
||||
event.get("log_format", "{log_text}") or "{log_text}"
|
||||
)
|
||||
return eventAsText(event, includeSystem=False) + "\n"
|
||||
|
||||
return FileLogObserver(outFile, formatEvent)
|
||||
|
||||
|
||||
class DrainType(Names):
|
||||
|
@ -155,30 +29,12 @@ class DrainType(Names):
|
|||
NETWORK_JSON_TERSE = NamedConstant()
|
||||
|
||||
|
||||
class OutputPipeType(Values):
|
||||
stdout = ValueConstant(sys.__stdout__)
|
||||
stderr = ValueConstant(sys.__stderr__)
|
||||
|
||||
|
||||
@attr.s
|
||||
class DrainConfiguration:
|
||||
name = attr.ib()
|
||||
type = attr.ib()
|
||||
location = attr.ib()
|
||||
options = attr.ib(default=None)
|
||||
|
||||
|
||||
@attr.s
|
||||
class NetworkJSONTerseOptions:
|
||||
maximum_buffer = attr.ib(type=int)
|
||||
|
||||
|
||||
DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
|
||||
DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
|
||||
|
||||
|
||||
def parse_drain_configs(
|
||||
drains: dict,
|
||||
) -> typing.Generator[DrainConfiguration, None, None]:
|
||||
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
"""
|
||||
Parse the drain configurations.
|
||||
|
||||
|
@ -186,11 +42,12 @@ def parse_drain_configs(
|
|||
drains (dict): A dict of drain configurations, keyed by name.
|
||||
|
||||
Yields:
|
||||
DrainConfiguration instances.
|
||||
dict instances representing a logging handler.
|
||||
|
||||
Raises:
|
||||
ConfigError: If any of the drain configuration items are invalid.
|
||||
"""
|
||||
|
||||
for name, config in drains.items():
|
||||
if "type" not in config:
|
||||
raise ConfigError("Logging drains require a 'type' key.")
|
||||
|
@ -202,6 +59,18 @@ def parse_drain_configs(
|
|||
"%s is not a known logging drain type." % (config["type"],)
|
||||
)
|
||||
|
||||
# Either use the default formatter or the tersejson one.
|
||||
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
|
||||
formatter = "json" # type: Optional[str]
|
||||
elif logging_type in (
|
||||
DrainType.CONSOLE_JSON_TERSE,
|
||||
DrainType.NETWORK_JSON_TERSE,
|
||||
):
|
||||
formatter = "tersejson"
|
||||
else:
|
||||
# A formatter of None implies using the default formatter.
|
||||
formatter = None
|
||||
|
||||
if logging_type in [
|
||||
DrainType.CONSOLE,
|
||||
DrainType.CONSOLE_JSON,
|
||||
|
@ -217,9 +86,11 @@ def parse_drain_configs(
|
|||
% (logging_type,)
|
||||
)
|
||||
|
||||
pipe = OutputPipeType.lookupByName(location).value
|
||||
|
||||
yield DrainConfiguration(name=name, type=logging_type, location=pipe)
|
||||
yield name, {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": formatter,
|
||||
"stream": "ext://sys." + location,
|
||||
}
|
||||
|
||||
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
|
||||
if "location" not in config:
|
||||
|
@ -233,18 +104,25 @@ def parse_drain_configs(
|
|||
"File paths need to be absolute, '%s' is a relative path"
|
||||
% (location,)
|
||||
)
|
||||
yield DrainConfiguration(name=name, type=logging_type, location=location)
|
||||
|
||||
yield name, {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": formatter,
|
||||
"filename": location,
|
||||
}
|
||||
|
||||
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
|
||||
host = config.get("host")
|
||||
port = config.get("port")
|
||||
maximum_buffer = config.get("maximum_buffer", 1000)
|
||||
yield DrainConfiguration(
|
||||
name=name,
|
||||
type=logging_type,
|
||||
location=(host, port),
|
||||
options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
|
||||
)
|
||||
|
||||
yield name, {
|
||||
"class": "synapse.logging.RemoteHandler",
|
||||
"formatter": formatter,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"maximum_buffer": maximum_buffer,
|
||||
}
|
||||
|
||||
else:
|
||||
raise ConfigError(
|
||||
|
@ -253,126 +131,29 @@ def parse_drain_configs(
|
|||
)
|
||||
|
||||
|
||||
class StoppableLogPublisher(LogPublisher):
|
||||
def setup_structured_logging(log_config: dict,) -> dict:
|
||||
"""
|
||||
A log publisher that can tell its observers to shut down any external
|
||||
communications.
|
||||
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
|
||||
to one compatible with the new standard library handlers.
|
||||
"""
|
||||
|
||||
def stop(self):
|
||||
for obs in self._observers:
|
||||
if hasattr(obs, "stop"):
|
||||
obs.stop()
|
||||
|
||||
|
||||
def setup_structured_logging(
|
||||
hs,
|
||||
config,
|
||||
log_config: dict,
|
||||
logBeginner: LogBeginner,
|
||||
redirect_stdlib_logging: bool = True,
|
||||
) -> LogPublisher:
|
||||
"""
|
||||
Set up Twisted's structured logging system.
|
||||
|
||||
Args:
|
||||
hs: The homeserver to use.
|
||||
config (HomeserverConfig): The configuration of the Synapse homeserver.
|
||||
log_config (dict): The log configuration to use.
|
||||
"""
|
||||
if config.no_redirect_stdio:
|
||||
raise ConfigError(
|
||||
"no_redirect_stdio cannot be defined using structured logging."
|
||||
)
|
||||
|
||||
logger = Logger()
|
||||
|
||||
if "drains" not in log_config:
|
||||
raise ConfigError("The logging configuration requires a list of drains.")
|
||||
|
||||
observers = [] # type: List[ILogObserver]
|
||||
new_config = {
|
||||
"version": 1,
|
||||
"formatters": {
|
||||
"json": {"class": "synapse.logging.JsonFormatter"},
|
||||
"tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
|
||||
},
|
||||
"handlers": {},
|
||||
"loggers": log_config.get("loggers", DEFAULT_LOGGERS),
|
||||
"root": {"handlers": []},
|
||||
}
|
||||
|
||||
for observer in parse_drain_configs(log_config["drains"]):
|
||||
# Pipe drains
|
||||
if observer.type == DrainType.CONSOLE:
|
||||
logger.debug(
|
||||
"Starting up the {name} console logger drain", name=observer.name
|
||||
)
|
||||
observers.append(SynapseFileLogObserver(observer.location))
|
||||
elif observer.type == DrainType.CONSOLE_JSON:
|
||||
logger.debug(
|
||||
"Starting up the {name} JSON console logger drain", name=observer.name
|
||||
)
|
||||
observers.append(jsonFileLogObserver(observer.location))
|
||||
elif observer.type == DrainType.CONSOLE_JSON_TERSE:
|
||||
logger.debug(
|
||||
"Starting up the {name} terse JSON console logger drain",
|
||||
name=observer.name,
|
||||
)
|
||||
observers.append(
|
||||
TerseJSONToConsoleLogObserver(observer.location, metadata={})
|
||||
)
|
||||
for handler_name, handler in parse_drain_configs(log_config["drains"]):
|
||||
new_config["handlers"][handler_name] = handler
|
||||
|
||||
# File drains
|
||||
elif observer.type == DrainType.FILE:
|
||||
logger.debug("Starting up the {name} file logger drain", name=observer.name)
|
||||
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
|
||||
observers.append(SynapseFileLogObserver(log_file))
|
||||
elif observer.type == DrainType.FILE_JSON:
|
||||
logger.debug(
|
||||
"Starting up the {name} JSON file logger drain", name=observer.name
|
||||
)
|
||||
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
|
||||
observers.append(jsonFileLogObserver(log_file))
|
||||
# Add each handler to the root logger.
|
||||
new_config["root"]["handlers"].append(handler_name)
|
||||
|
||||
elif observer.type == DrainType.NETWORK_JSON_TERSE:
|
||||
metadata = {"server_name": hs.config.server_name}
|
||||
log_observer = TerseJSONToTCPLogObserver(
|
||||
hs=hs,
|
||||
host=observer.location[0],
|
||||
port=observer.location[1],
|
||||
metadata=metadata,
|
||||
maximum_buffer=observer.options.maximum_buffer,
|
||||
)
|
||||
log_observer.start()
|
||||
observers.append(log_observer)
|
||||
else:
|
||||
# We should never get here, but, just in case, throw an error.
|
||||
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
|
||||
|
||||
publisher = StoppableLogPublisher(*observers)
|
||||
log_filter = LogLevelFilterPredicate()
|
||||
|
||||
for namespace, namespace_config in log_config.get(
|
||||
"loggers", DEFAULT_LOGGERS
|
||||
).items():
|
||||
# Set the log level for twisted.logger.Logger namespaces
|
||||
log_filter.setLogLevelForNamespace(
|
||||
namespace,
|
||||
stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
|
||||
)
|
||||
|
||||
# Also set the log levels for the stdlib logger namespaces, to prevent
|
||||
# them getting to PythonStdlibToTwistedLogger and having to be formatted
|
||||
if "level" in namespace_config:
|
||||
logging.getLogger(namespace).setLevel(namespace_config.get("level"))
|
||||
|
||||
f = FilteringLogObserver(publisher, [log_filter])
|
||||
lco = LogContextObserver(f)
|
||||
|
||||
if redirect_stdlib_logging:
|
||||
stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
|
||||
stdliblogger = logging.getLogger()
|
||||
stdliblogger.addHandler(stuff_into_twisted)
|
||||
|
||||
# Always redirect standard I/O, otherwise other logging outputs might miss
|
||||
# it.
|
||||
logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
|
||||
|
||||
return publisher
|
||||
|
||||
|
||||
def reload_structured_logging(*args, log_config=None) -> None:
|
||||
warnings.warn(
|
||||
"Currently the structured logging system can not be reloaded, doing nothing"
|
||||
)
|
||||
return new_config
|
||||
|
|
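After this change `setup_structured_logging` returns a plain `dictConfig`-compatible dictionary rather than a Twisted log publisher. For a legacy config containing a single `NETWORK_JSON_TERSE` drain named `remote`, the returned structure looks roughly like this (host and port values are illustrative):

```python
converted = {
    "version": 1,
    "formatters": {
        "json": {"class": "synapse.logging.JsonFormatter"},
        "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
    },
    "handlers": {
        "remote": {
            "class": "synapse.logging.RemoteHandler",
            "formatter": "tersejson",
            "host": "10.1.2.3",
            "port": 9000,
            "maximum_buffer": 1000,
        },
    },
    "loggers": {"synapse": {"level": "info"}},  # DEFAULT_LOGGERS when none given
    "root": {"handlers": ["remote"]},           # every drain is attached to root
}
```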
|
@ -16,141 +16,65 @@
|
|||
"""
|
||||
Log formatters that output terse JSON.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import IO
|
||||
|
||||
from twisted.logger import FileLogObserver
|
||||
|
||||
from synapse.logging._remote import TCPLogObserver
|
||||
import logging
|
||||
|
||||
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
|
||||
"""
|
||||
Flatten a Twisted logging event to a dictionary capable of being sent
|
||||
as a log event to a logging aggregation system.
|
||||
|
||||
The format is vastly simplified and is not designed to be a "human readable
|
||||
string" in the sense that traditional logs are. Instead, the structure is
|
||||
optimised for searchability and filtering, with human-understandable log
|
||||
keys.
|
||||
|
||||
Args:
|
||||
event (dict): The Twisted logging event we are flattening.
|
||||
metadata (dict): Additional data to include with each log message. This
|
||||
can be information like the server name. Since the target log
|
||||
consumer does not know who we are other than by host IP, this
|
||||
allows us to forward through static information.
|
||||
include_time (bool): Should we include the `time` key? If False, the
|
||||
event time is stripped from the event.
|
||||
"""
|
||||
new_event = {}
|
||||
|
||||
# If it's a failure, make the new event's log_failure be the traceback text.
|
||||
if "log_failure" in event:
|
||||
new_event["log_failure"] = event["log_failure"].getTraceback()
|
||||
|
||||
# If it's a warning, copy over a string representation of the warning.
|
||||
if "warning" in event:
|
||||
new_event["warning"] = str(event["warning"])
|
||||
|
||||
# Stdlib logging events have "log_text" as their human-readable portion,
|
||||
# Twisted ones have "log_format". For now, include the log_format, so that
|
||||
# context only given in the log format (e.g. what is being logged) is
|
||||
# available.
|
||||
if "log_text" in event:
|
||||
new_event["log"] = event["log_text"]
|
||||
else:
|
||||
new_event["log"] = event["log_format"]
|
||||
|
||||
# We want to include the timestamp when forwarding over the network, but
|
||||
# exclude it when we are writing to stdout. This is because the log ingester
|
||||
# (e.g. logstash, fluentd) can add its own timestamp.
|
||||
if include_time:
|
||||
new_event["time"] = round(event["log_time"], 2)
|
||||
|
||||
# Convert the log level to a textual representation.
|
||||
new_event["level"] = event["log_level"].name.upper()
|
||||
|
||||
# Ignore these keys, and do not transfer them over to the new log object.
|
||||
# They are either useless (isError), transferred manually above (log_time,
|
||||
# log_level, etc), or contain Python objects which are not useful for output
|
||||
# (log_logger, log_source).
|
||||
keys_to_delete = [
|
||||
"isError",
|
||||
"log_failure",
|
||||
"log_format",
|
||||
"log_level",
|
||||
"log_logger",
|
||||
"log_source",
|
||||
"log_system",
|
||||
"log_time",
|
||||
"log_text",
|
||||
"observer",
|
||||
"warning",
|
||||
]
|
||||
|
||||
# If it's from the Twisted legacy logger (twisted.python.log), it adds some
|
||||
# more keys we want to purge.
|
||||
if event.get("log_namespace") == "log_legacy":
|
||||
keys_to_delete.extend(["message", "system", "time"])
|
||||
|
||||
# Rather than modify the dictionary in place, construct a new one with only
|
||||
# the content we want. The original event should be considered 'frozen'.
|
||||
for key in event.keys():
|
||||
|
||||
if key in keys_to_delete:
|
||||
continue
|
||||
|
||||
if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
|
||||
# If it's a plain type, include it as is.
|
||||
new_event[key] = event[key]
|
||||
else:
|
||||
# If it's not one of those basic types, write out a string
|
||||
# representation. This should probably be a warning in development,
|
||||
# so that we are sure we are only outputting useful data.
|
||||
new_event[key] = str(event[key])
|
||||
|
||||
# Add the metadata information to the event (e.g. the server_name).
|
||||
new_event.update(metadata)
|
||||
|
||||
return new_event
|
||||
# The properties of a standard LogRecord.
|
||||
_LOG_RECORD_ATTRIBUTES = {
|
||||
"args",
|
||||
"asctime",
|
||||
"created",
|
||||
"exc_info",
|
||||
# exc_text isn't a public attribute, but is used to cache the result of formatException.
|
||||
"exc_text",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"message",
|
||||
"module",
|
||||
"msecs",
|
||||
"msg",
|
||||
"name",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"stack_info",
|
||||
"thread",
|
||||
"threadName",
|
||||
}
|
||||
|
||||
|
||||
def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
|
||||
"""
|
||||
A log observer that formats events to a flattened JSON representation.
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
event = {
|
||||
"log": record.getMessage(),
|
||||
"namespace": record.name,
|
||||
"level": record.levelname,
|
||||
}
|
||||
|
||||
Args:
|
||||
outFile: The file object to write to.
|
||||
metadata: Metadata to be added to each log object.
|
||||
"""
|
||||
return self._format(record, event)
|
||||
|
||||
def formatEvent(_event: dict) -> str:
|
||||
flattened = flatten_event(_event, metadata)
|
||||
return _encoder.encode(flattened) + "\n"
|
||||
def _format(self, record: logging.LogRecord, event: dict) -> str:
|
||||
# Add any extra attributes to the event.
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in _LOG_RECORD_ATTRIBUTES:
|
||||
event[key] = value
|
||||
|
||||
return FileLogObserver(outFile, formatEvent)
|
||||
return _encoder.encode(event)
|
||||
|
||||
|
||||
def TerseJSONToTCPLogObserver(
|
||||
hs, host: str, port: int, metadata: dict, maximum_buffer: int
|
||||
) -> FileLogObserver:
|
||||
"""
|
||||
A log observer that formats events to a flattened JSON representation.
|
||||
class TerseJsonFormatter(JsonFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
event = {
|
||||
"log": record.getMessage(),
|
||||
"namespace": record.name,
|
||||
"level": record.levelname,
|
||||
"time": round(record.created, 2),
|
||||
}
|
||||
|
||||
Args:
|
||||
hs (HomeServer): The homeserver that is being logged for.
|
||||
host: The host of the logging target.
|
||||
port: The logging target's port.
|
||||
metadata: Metadata to be added to each log object.
|
||||
maximum_buffer: The maximum buffer size.
|
||||
"""
|
||||
|
||||
def formatEvent(_event: dict) -> str:
|
||||
flattened = flatten_event(_event, metadata, include_time=True)
|
||||
return _encoder.encode(flattened) + "\n"
|
||||
|
||||
return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
|
||||
return self._format(record, event)
|
||||
|
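A hedged usage sketch for the new formatters: attach `TerseJsonFormatter` to an ordinary `StreamHandler`, and any non-standard `LogRecord` attributes (for example ones injected via `extra=` or a filter) are copied into the JSON event alongside `log`, `namespace`, `level` and `time`:

```python
import logging
import sys

from synapse.logging import TerseJsonFormatter

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(TerseJsonFormatter())

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Produces a single JSON line, roughly:
# {"log": "hello", "namespace": "demo", "level": "INFO", "time": <ts>, "request": "GET-123"}
logger.info("hello", extra={"request": "GET-123"})
```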
|
33
synapse/logging/filter.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class MetadataFilter(logging.Filter):
|
||||
"""Logging filter that adds constant values to each record.
|
||||
|
||||
Args:
|
||||
metadata: Key-value pairs to add to each record.
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: dict):
|
||||
self._metadata = metadata
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> Literal[True]:
|
||||
for key, value in self._metadata.items():
|
||||
setattr(record, key, value)
|
||||
return True
|
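A short usage sketch for `MetadataFilter`: attached to a handler (or logger), it stamps the same key/value pairs onto every record, which the JSON formatters above then emit as extra fields (the server name here is a placeholder):

```python
import logging

from synapse.logging.filter import MetadataFilter

handler = logging.StreamHandler()
handler.addFilter(MetadataFilter({"server_name": "example.com"}))

logger = logging.getLogger("demo")
logger.addHandler(handler)

# After filtering, record.server_name == "example.com" on every record
# that passes through this handler.
logger.warning("hello")
```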
|
@ -28,6 +28,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
|
|||
return bool(self.events)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class _PendingRoomEventEntry:
|
||||
event_pos = attr.ib(type=PersistedEventPosition)
|
||||
extra_users = attr.ib(type=Collection[UserID])
|
||||
|
||||
room_id = attr.ib(type=str)
|
||||
type = attr.ib(type=str)
|
||||
state_key = attr.ib(type=Optional[str])
|
||||
membership = attr.ib(type=Optional[str])
|
||||
|
||||
|
||||
class Notifier:
|
||||
""" This class is responsible for notifying any listeners when there are
|
||||
new events available for it.
|
||||
|
@ -190,9 +202,7 @@ class Notifier:
|
|||
self.storage = hs.get_storage()
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.store = hs.get_datastore()
|
||||
self.pending_new_room_events = (
|
||||
[]
|
||||
) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
|
||||
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
|
||||
|
||||
# Called when there are new things to stream over replication
|
||||
self.replication_callbacks = [] # type: List[Callable[[], None]]
|
||||
|
@ -254,6 +264,28 @@ class Notifier:
|
|||
event_pos: PersistedEventPosition,
|
||||
max_room_stream_token: RoomStreamToken,
|
||||
extra_users: Collection[UserID] = [],
|
||||
):
|
||||
"""Unwraps event and calls `on_new_room_event_args`.
|
||||
"""
|
||||
self.on_new_room_event_args(
|
||||
event_pos=event_pos,
|
||||
room_id=event.room_id,
|
||||
event_type=event.type,
|
||||
state_key=event.get("state_key"),
|
||||
membership=event.content.get("membership"),
|
||||
max_room_stream_token=max_room_stream_token,
|
||||
extra_users=extra_users,
|
||||
)
|
||||
|
||||
def on_new_room_event_args(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: Optional[str],
|
||||
membership: Optional[str],
|
||||
event_pos: PersistedEventPosition,
|
||||
max_room_stream_token: RoomStreamToken,
|
||||
extra_users: Collection[UserID] = [],
|
||||
):
|
||||
"""Used by handlers to inform the notifier something has happened
|
||||
in the room, room event wise.
|
||||
|
@ -266,7 +298,16 @@ class Notifier:
|
|||
until all previous events have been persisted before notifying
|
||||
the client streams.
|
||||
"""
|
||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
||||
self.pending_new_room_events.append(
|
||||
_PendingRoomEventEntry(
|
||||
event_pos=event_pos,
|
||||
extra_users=extra_users,
|
||||
room_id=room_id,
|
||||
type=event_type,
|
||||
state_key=state_key,
|
||||
membership=membership,
|
||||
)
|
||||
)
|
||||
self._notify_pending_new_room_events(max_room_stream_token)
|
||||
|
||||
self.notify_replication()
|
||||
|
@ -284,18 +325,19 @@ class Notifier:
|
|||
users = set() # type: Set[UserID]
|
||||
rooms = set() # type: Set[str]
|
||||
|
||||
for event_pos, event, extra_users in pending:
|
||||
if event_pos.persisted_after(max_room_stream_token):
|
||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
||||
for entry in pending:
|
||||
if entry.event_pos.persisted_after(max_room_stream_token):
|
||||
self.pending_new_room_events.append(entry)
|
||||
else:
|
||||
if (
|
||||
event.type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
entry.type == EventTypes.Member
|
||||
and entry.membership == Membership.JOIN
|
||||
and entry.state_key
|
||||
):
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
self._user_joined_room(entry.state_key, entry.room_id)
|
||||
|
||||
users.update(extra_users)
|
||||
rooms.add(event.room_id)
|
||||
users.update(entry.extra_users)
|
||||
rooms.add(entry.room_id)
|
||||
|
||||
if users or rooms:
|
||||
self.on_new_event(
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||
|
@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
|
|||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.descriptors import lru_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
|
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
|
|||
dict of user_id -> push_rules
|
||||
"""
|
||||
room_id = event.room_id
|
||||
rules_for_room = await self._get_rules_for_room(room_id)
|
||||
rules_for_room = self._get_rules_for_room(room_id)
|
||||
|
||||
rules_by_user = await rules_for_room.get_rules(event, context)
|
||||
|
||||
|
@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
|
|||
|
||||
return rules_by_user
|
||||
|
||||
@cached()
|
||||
@lru_cache()
|
||||
def _get_rules_for_room(self, room_id):
|
||||
"""Get the current RulesForRoom object for the given room id
|
||||
|
||||
|
@ -275,12 +276,14 @@ class RulesForRoom:
|
|||
the entire cache for the room.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
|
||||
def __init__(
|
||||
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
room_id (str)
|
||||
rules_for_room_cache(Cache): The cache object that caches these
|
||||
rules_for_room_cache: The cache object that caches these
|
||||
RoomsForUser objects.
|
||||
room_push_rule_cache_metrics (CacheMetric)
|
||||
"""
|
||||
|
@ -489,13 +492,21 @@ class RulesForRoom:
|
|||
self.state_group = state_group
|
||||
|
||||
|
||||
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
|
||||
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
|
||||
# which namedtuple does for us (i.e. two _CacheContext are the same if
|
||||
# their caches and keys match). This is important in particular to
|
||||
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
||||
# of callbacks would grow.
|
||||
@attr.attrs(slots=True, frozen=True)
|
||||
class _Invalidation:
|
||||
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
|
||||
# which means that it is stored on the bulk_get_push_rules cache entry. In order
|
||||
# to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
|
||||
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
|
||||
# same `cache` and `room_id`.
|
||||
#
|
||||
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
|
||||
# set `frozen=True`.
|
||||
|
||||
cache = attr.ib(type=LruCache)
|
||||
room_id = attr.ib(type=str)
|
||||
|
||||
def __call__(self):
|
||||
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
|
||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||
if rules:
|
||||
rules.invalidate_all()
|
||||
|
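The comment above relies on frozen attrs classes providing value-based `__eq__` and `__hash__`. A tiny illustration (not Synapse code) of why that is enough to dedupe the invalidation callbacks:

```python
import attr

@attr.s(slots=True, frozen=True)
class Callback:
    cache = attr.ib()
    room_id = attr.ib(type=str)

shared_cache = object()
callbacks = {
    Callback(shared_cache, "!room:example.com"),
    Callback(shared_cache, "!room:example.com"),
}
assert len(callbacks) == 1  # equal field values collapse to one entry
```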
|
|
@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
|
||||
if requester.user:
|
||||
request.authenticated_entity = requester.user.to_string()
|
||||
request.requester = requester
|
||||
|
||||
logger.info("remote_join: %s into room: %s", user_id, room_id)
|
||||
|
||||
|
@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
|
||||
if requester.user:
|
||||
request.authenticated_entity = requester.user.to_string()
|
||||
request.requester = requester
|
||||
|
||||
# hopefully we're now on the master, so this won't recurse!
|
||||
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
||||
|
|
|
@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
ratelimit = content["ratelimit"]
|
||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||
|
||||
if requester.user:
|
||||
request.authenticated_entity = requester.user.to_string()
|
||||
request.requester = requester
|
||||
|
||||
logger.info(
|
||||
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
||||
|
|
|
@ -141,21 +141,25 @@ class ReplicationDataHandler:
|
|||
if row.type != EventsStreamEventRow.TypeId:
|
||||
continue
|
||||
assert isinstance(row, EventsStreamRow)
|
||||
assert isinstance(row.data, EventsStreamEventRow)
|
||||
|
||||
event = await self.store.get_event(
|
||||
row.data.event_id, allow_rejected=True
|
||||
)
|
||||
if event.rejected_reason:
|
||||
if row.data.rejected:
|
||||
continue
|
||||
|
||||
extra_users = () # type: Tuple[UserID, ...]
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (UserID.from_string(event.state_key),)
|
||||
if row.data.type == EventTypes.Member and row.data.state_key:
|
||||
extra_users = (UserID.from_string(row.data.state_key),)
|
||||
|
||||
max_token = self.store.get_room_max_token()
|
||||
event_pos = PersistedEventPosition(instance_name, token)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_pos, max_token, extra_users
|
||||
self.notifier.on_new_room_event_args(
|
||||
event_pos=event_pos,
|
||||
max_room_stream_token=max_token,
|
||||
extra_users=extra_users,
|
||||
room_id=row.data.room_id,
|
||||
event_type=row.data.type,
|
||||
state_key=row.data.state_key,
|
||||
membership=row.data.membership,
|
||||
)
|
||||
|
||||
# Notify any waiting deferreds. The list is ordered by position so we
|
||||
|
|
|
@ -15,12 +15,15 @@
|
|||
# limitations under the License.
|
||||
import heapq
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Tuple, Type
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
This stream contains rows of various types. Each row therefore contains a 'type'
|
||||
|
@ -81,12 +84,14 @@ class BaseEventsStreamRow:
|
|||
class EventsStreamEventRow(BaseEventsStreamRow):
|
||||
TypeId = "ev"
|
||||
|
||||
event_id = attr.ib() # str
|
||||
room_id = attr.ib() # str
|
||||
type = attr.ib() # str
|
||||
state_key = attr.ib() # str, optional
|
||||
redacts = attr.ib() # str, optional
|
||||
relates_to = attr.ib() # str, optional
|
||||
event_id = attr.ib(type=str)
|
||||
room_id = attr.ib(type=str)
|
||||
type = attr.ib(type=str)
|
||||
state_key = attr.ib(type=Optional[str])
|
||||
redacts = attr.ib(type=Optional[str])
|
||||
relates_to = attr.ib(type=Optional[str])
|
||||
membership = attr.ib(type=Optional[str])
|
||||
rejected = attr.ib(type=bool)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
|
@ -113,7 +118,7 @@ class EventsStream(Stream):
|
|||
|
||||
NAME = "events"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
|
|
|
@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
|
|||
from synapse.rest.admin.users import (
|
||||
AccountValidityRenewServlet,
|
||||
DeactivateAccountRestServlet,
|
||||
PushersRestServlet,
|
||||
ResetPasswordRestServlet,
|
||||
SearchUsersRestServlet,
|
||||
UserAdminServlet,
|
||||
|
@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
|
|||
DeviceRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
EventReportsRestServlet(hs).register(http_server)
|
||||
EventReportDetailRestServlet(hs).register(http_server)
|
||||
EventReportsRestServlet(hs).register(http_server)
|
||||
PushersRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_servlets_for_client_rest_resource(hs, http_server):
|
||||
|
|
|
@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GET_PUSHERS_ALLOWED_KEYS = {
|
||||
"app_display_name",
|
||||
"app_id",
|
||||
"data",
|
||||
"device_display_name",
|
||||
"kind",
|
||||
"lang",
|
||||
"profile_tag",
|
||||
"pushkey",
|
||||
}
|
||||
|
||||
|
||||
class UsersRestServlet(RestServlet):
|
||||
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
|
||||
|
@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
|
|||
return 200, ret
|
||||
|
||||
|
||||
class PushersRestServlet(RestServlet):
|
||||
"""
|
||||
Gets information about all pushers for a specific `user_id`.
|
||||
|
||||
Example:
|
||||
http://localhost:8008/_synapse/admin/v1/users/
|
||||
@user:server/pushers
|
||||
|
||||
Returns:
|
||||
pushers: Dictionary containing pushers information.
|
||||
total: Number of pushers in dictionary `pushers`.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.is_mine = hs.is_mine
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
raise SynapseError(400, "Can only lookup local users")
|
||||
|
||||
if not await self.store.get_user_by_id(user_id):
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
pushers = await self.store.get_pushers_by_user_id(user_id)
|
||||
|
||||
filtered_pushers = [
|
||||
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
|
||||
for p in pushers
|
||||
]
|
||||
|
||||
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
|
||||
|
||||
|
||||
class UserMediaRestServlet(RestServlet):
|
||||
"""
|
||||
Gets information about all uploaded local media for a specific `user_id`.
|
||||
|
|
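A hedged example of calling the new admin endpoint (server address, user ID and token are placeholders; the response keys follow the servlet docstring and `_GET_PUSHERS_ALLOWED_KEYS` above):

```python
import requests

resp = requests.get(
    "http://localhost:8008/_synapse/admin/v1/users/@alice:example.com/pushers",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
resp.raise_for_status()
data = resp.json()

print(data["total"])                 # number of pushers for the user
for pusher in data["pushers"]:
    print(pusher["app_id"], pusher["pushkey"])
```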
|
@ -305,15 +305,12 @@ class MediaRepository:
|
|||
# file_id is the ID we use to track the file locally. If we've already
|
||||
# seen the file then reuse the existing ID, otherwise generate a new
|
||||
# one.
|
||||
if media_info:
|
||||
file_id = media_info["filesystem_id"]
|
||||
else:
|
||||
file_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
# If we have an entry in the DB, try and look for it
|
||||
if media_info:
|
||||
file_id = media_info["filesystem_id"]
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
raise NotFoundError()
|
||||
|
@ -324,14 +321,34 @@ class MediaRepository:
|
|||
|
||||
# Failed to find the file anywhere, lets download it.
|
||||
|
||||
media_info = await self._download_remote_file(server_name, media_id, file_id)
|
||||
try:
|
||||
media_info = await self._download_remote_file(server_name, media_id,)
|
||||
except SynapseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# An exception may be because we downloaded media in another
|
||||
# process, so let's check if we magically have the media.
|
||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||
if not media_info:
|
||||
raise e
|
||||
|
||||
file_id = media_info["filesystem_id"]
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
# We generate thumbnails even if another process downloaded the media
|
||||
# as a) it's conceivable that the other download request dies before it
|
||||
# generates thumbnails, but mainly b) we want to be sure the thumbnails
|
||||
# have finished being generated before responding to the client,
|
||||
# otherwise they'll request thumbnails and get a 404 if they're not
|
||||
# ready yet.
|
||||
await self._generate_thumbnails(
|
||||
server_name, media_id, file_id, media_info["media_type"]
|
||||
)
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
return responder, media_info
|
||||
|
||||
async def _download_remote_file(
|
||||
self, server_name: str, media_id: str, file_id: str
|
||||
) -> dict:
|
||||
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
|
||||
|
@ -346,6 +363,8 @@ class MediaRepository:
|
|||
The media info of the file.
|
||||
"""
|
||||
|
||||
file_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
|
@ -405,8 +424,16 @@ class MediaRepository:
|
|||
upload_name = get_filename_from_headers(headers)
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
# Multiple remote media download requests can race (when using
|
||||
# multiple media repos), so this may throw a violation constraint
|
||||
# exception. If it does we'll delete the newly downloaded file from
|
||||
# disk (as we're in the ctx manager).
|
||||
#
|
||||
# However: we've already called `finish()` so we may have also
|
||||
# written to the storage providers. This is preferable to the
|
||||
# alternative where we call `finish()` *after* this, where we could
|
||||
# end up having an entry in the DB but fail to write the files to
|
||||
# the storage providers.
|
||||
await self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
|
@ -417,6 +444,8 @@ class MediaRepository:
|
|||
filesystem_id=file_id,
|
||||
)
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
|
@ -425,8 +454,6 @@ class MediaRepository:
|
|||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
|
||||
|
||||
return media_info
|
||||
|
||||
def _get_thumbnail_requirements(self, media_type):
|
||||
|
@ -692,7 +719,6 @@ class MediaRepository:
|
|||
if not t_byte_source:
|
||||
continue
|
||||
|
||||
try:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
|
@ -704,16 +730,29 @@ class MediaRepository:
|
|||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
output_path = await self.media_storage.store_file(
|
||||
t_byte_source, file_info
|
||||
)
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
await self.media_storage.write_to_file(t_byte_source, f)
|
||||
await finish()
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
t_len = os.path.getsize(fname)
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
# Multiple remote media download requests can race (when
|
||||
# using multiple media repos), so this may throw a violation
|
||||
# constraint exception. If it does we'll delete the newly
|
||||
# generated thumbnail from disk (as we're in the ctx
|
||||
# manager).
|
||||
#
|
||||
# However: we've already called `finish()` so we may have
|
||||
# also written to the storage providers. This is preferable
|
||||
# to the alternative where we call `finish()` *after* this,
|
||||
# where we could end up having an entry in the DB but fail
|
||||
# to write the files to the storage providers.
|
||||
try:
|
||||
await self.store.store_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
|
@ -724,6 +763,12 @@ class MediaRepository:
|
|||
t_method,
|
||||
t_len,
|
||||
)
|
||||
except Exception as e:
|
||||
thumbnail_exists = await self.store.get_remote_media_thumbnail(
|
||||
server_name, media_id, t_width, t_height, t_type,
|
||||
)
|
||||
if not thumbnail_exists:
|
||||
raise e
|
||||
else:
|
||||
await self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
|
|
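A simplified sketch (not the full method, and with a hypothetical wrapper name) of the race handling introduced above: if the remote download fails, check whether another media-repository worker has already populated the cache before re-raising; deliberate `SynapseError`s still propagate:

```python
from synapse.api.errors import SynapseError

async def fetch_remote_media_info(store, download, server_name: str, media_id: str) -> dict:
    try:
        return await download(server_name, media_id)
    except SynapseError:
        # Deliberate errors (e.g. the origin returned a 404) propagate unchanged.
        raise
    except Exception:
        # Another worker may have raced us and downloaded the media already.
        media_info = await store.get_cached_remote_media(server_name, media_id)
        if not media_info:
            raise
        return media_info
```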
|
@ -52,6 +52,7 @@ class MediaStorage:
|
|||
storage_providers: Sequence["StorageProviderWrapper"],
|
||||
):
|
||||
self.hs = hs
|
||||
self.reactor = hs.get_reactor()
|
||||
self.local_media_directory = local_media_directory
|
||||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
|
@ -70,13 +71,16 @@ class MediaStorage:
|
|||
|
||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||
# Write to the main repository
|
||||
await defer_to_thread(
|
||||
self.hs.get_reactor(), _write_file_synchronously, source, f
|
||||
)
|
||||
await self.write_to_file(source, f)
|
||||
await finish_cb()
|
||||
|
||||
return fname
|
||||
|
||||
async def write_to_file(self, source: IO, output: IO):
|
||||
"""Asynchronously write the `source` to `output`.
|
||||
"""
|
||||
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_into_file(self, file_info: FileInfo):
|
||||
"""Context manager used to get a file like object to write into, as
|
||||
|
@ -112,14 +116,20 @@ class MediaStorage:
|
|||
|
||||
finished_called = [False]
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
|
||||
async def finish():
|
||||
# Ensure that all writes have been flushed and close the
|
||||
# file.
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
for provider in self.storage_providers:
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
yield f, fname, finish
|
||||
except Exception:
|
||||
try:
|
||||
|
@ -210,7 +220,7 @@ class MediaStorage:
|
|||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
open(local_path, "wb"), self.hs.get_reactor()
|
||||
open(local_path, "wb"), self.reactor
|
||||
)
|
||||
await res.write_to_consumer(consumer)
|
||||
await consumer.wait()
|
||||
|
|
|
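For reference, the store_file rewrite above boils down to the following use of the new write_to_file helper inside the store_into_file context manager; the wrapper function here is purely illustrative, with media_storage, file_info and source assumed to exist in the caller's scope:

    async def store(media_storage, file_info, source):
        with media_storage.store_into_file(file_info) as (f, fname, finish):
            # write_to_file offloads the blocking copy to a thread...
            await media_storage.write_to_file(source, f)
            # ...and finish() flushes/closes the file and pushes it to the
            # configured storage providers.
            await finish()
        return fname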
@ -94,7 +94,7 @@ def make_pool(
|
|||
cp_openfun=lambda conn: engine.on_new_connection(
|
||||
LoggingDatabaseConnection(conn, engine, "on_new_connection")
|
||||
),
|
||||
**db_config.config.get("args", {})
|
||||
**db_config.config.get("args", {}),
|
||||
)
|
||||
|
||||
|
||||
|
@ -632,7 +632,7 @@ class DatabasePool:
|
|||
func,
|
||||
*args,
|
||||
db_autocommit=db_autocommit,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
|
|
|
@ -15,21 +15,31 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
||||
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.appservice import (
|
||||
ApplicationService,
|
||||
ApplicationServiceState,
|
||||
AppServiceTransaction,
|
||||
)
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_exclusive_regex(services_cache):
|
||||
def _make_exclusive_regex(
|
||||
services_cache: List[ApplicationService],
|
||||
) -> Optional[Pattern]:
|
||||
# We precompile a regex constructed from all the regexes that the AS's
|
||||
# have registered for exclusive users.
|
||||
exclusive_user_regexes = [
|
||||
|
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
|
|||
]
|
||||
if exclusive_user_regexes:
|
||||
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
||||
exclusive_user_regex = re.compile(exclusive_user_regex)
|
||||
exclusive_user_pattern = re.compile(
|
||||
exclusive_user_regex
|
||||
) # type: Optional[Pattern]
|
||||
else:
|
||||
# We handle this case specially otherwise the constructed regex
|
||||
# will always match
|
||||
exclusive_user_regex = None
|
||||
exclusive_user_pattern = None
|
||||
|
||||
return exclusive_user_regex
|
||||
return exclusive_user_pattern
|
||||
|
||||
|
||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname, hs.config.app_service_config_files
|
||||
)
|
||||
|
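The _make_exclusive_regex change above keeps the old behaviour: each exclusive-user regex is wrapped in a group and joined with "|", and None is returned when there are none so the caller never ends up with a pattern that matches every user. A standalone sketch over plain regex strings (the real code pulls the regexes out of ApplicationService namespaces):

    import re
    from typing import Iterable, Optional, Pattern

    def make_exclusive_regex(exclusive_user_regexes: Iterable[str]) -> Optional[Pattern]:
        regexes = list(exclusive_user_regexes)
        if not regexes:
            # An empty alternation would match everything, so signal "no
            # exclusive users" with None instead.
            return None
        return re.compile("|".join("(" + r + ")" for r in regexes))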
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
def get_app_services(self):
|
||||
return self.services_cache
|
||||
|
||||
def get_if_app_services_interested_in_user(self, user_id):
|
||||
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||
"""Check if the user is one associated with an app service (exclusively)
|
||||
"""
|
||||
if self.exclusive_user_regex:
|
||||
|
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
else:
|
||||
return False
|
||||
|
||||
def get_app_service_by_user_id(self, user_id):
|
||||
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
|
||||
"""Retrieve an application service from their user ID.
|
||||
|
||||
All application services have associated with them a particular user ID.
|
||||
|
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
a user ID to an application service.
|
||||
|
||||
Args:
|
||||
user_id(str): The user ID to see if it is an application service.
|
||||
user_id: The user ID to see if it is an application service.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.sender == user_id:
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_by_token(self, token):
|
||||
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
|
||||
"""Get the application service with the given appservice token.
|
||||
|
||||
Args:
|
||||
token (str): The application service token.
|
||||
token: The application service token.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_by_id(self, as_id):
|
||||
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
|
||||
"""Get the application service with the given appservice ID.
|
||||
|
||||
Args:
|
||||
as_id (str): The application service ID.
|
||||
as_id: The application service ID.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.id == as_id:
|
||||
|
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
|||
class ApplicationServiceTransactionWorkerStore(
|
||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||
):
|
||||
async def get_appservices_by_state(self, state):
|
||||
async def get_appservices_by_state(
|
||||
self, state: ApplicationServiceState
|
||||
) -> List[ApplicationService]:
|
||||
"""Get a list of application services based on their state.
|
||||
|
||||
Args:
|
||||
state(ApplicationServiceState): The state to filter on.
|
||||
state: The state to filter on.
|
||||
Returns:
|
||||
A list of ApplicationServices, which may be empty.
|
||||
"""
|
||||
|
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
services.append(service)
|
||||
return services
|
||||
|
||||
async def get_appservice_state(self, service):
|
||||
async def get_appservice_state(
|
||||
self, service: ApplicationService
|
||||
) -> Optional[ApplicationServiceState]:
|
||||
"""Get the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
service: The service whose state to retrieve.
|
||||
Returns:
|
||||
An ApplicationServiceState.
|
||||
An ApplicationServiceState or None.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one(
|
||||
"application_services_state",
|
||||
|
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
return result.get("state")
|
||||
return None
|
||||
|
||||
async def set_appservice_state(self, service, state) -> None:
|
||||
async def set_appservice_state(
|
||||
self, service: ApplicationService, state: ApplicationServiceState
|
||||
) -> None:
|
||||
"""Set the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
state(ApplicationServiceState): The connectivity state to apply.
|
||||
service: The service whose state to set.
|
||||
state: The connectivity state to apply.
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
|
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"create_appservice_txn", _create_appservice_txn
|
||||
)
|
||||
|
||||
async def complete_appservice_txn(self, txn_id, service) -> None:
|
||||
async def complete_appservice_txn(
|
||||
self, txn_id: int, service: ApplicationService
|
||||
) -> None:
|
||||
"""Completes an application service transaction.
|
||||
|
||||
Args:
|
||||
txn_id(str): The transaction ID being completed.
|
||||
service(ApplicationService): The application service which was sent
|
||||
this transaction.
|
||||
txn_id: The transaction ID being completed.
|
||||
service: The application service which was sent this transaction.
|
||||
"""
|
||||
txn_id = int(txn_id)
|
||||
|
||||
|
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
# has probably missed some events), so whine loudly but still continue,
|
||||
# since it shouldn't fail completion of the transaction.
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
if (last_txn_id + 1) != txn_id:
|
||||
if (txn_id + 1) != txn_id:
|
||||
logger.error(
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
|
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"complete_appservice_txn", _complete_appservice_txn
|
||||
)
|
||||
|
||||
async def get_oldest_unsent_txn(self, service):
|
||||
"""Get the oldest transaction which has not been sent for this
|
||||
service.
|
||||
async def get_oldest_unsent_txn(
|
||||
self, service: ApplicationService
|
||||
) -> Optional[AppServiceTransaction]:
|
||||
"""Get the oldest transaction which has not been sent for this service.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The app service to get the oldest txn.
|
||||
service: The app service to get the oldest txn.
|
||||
Returns:
|
||||
An AppServiceTransaction or None.
|
||||
"""
|
||||
|
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,),
|
||||
|
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
async def set_appservice_last_pos(self, pos) -> None:
|
||||
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
|
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
async def get_new_events_for_appservice(self, current_id, limit):
|
||||
async def get_new_events_for_appservice(
|
||||
self, current_id: int, limit: int
|
||||
) -> Tuple[int, List[EventBase]]:
|
||||
"""Get all new events for an appservice"""
|
||||
|
||||
def get_new_events_for_appservice_txn(txn):
|
||||
|
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
)
|
||||
|
||||
async def set_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str, pos: int
|
||||
self, service: ApplicationService, type: str, pos: Optional[int]
|
||||
) -> None:
|
||||
if type not in ("read_receipt", "presence"):
|
||||
raise ValueError(
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
|
|||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
and original_event.internal_metadata.is_redacted()
|
||||
):
|
||||
# Redaction was allowed
|
||||
pruned_json = frozendict_json_encoder.encode(
|
||||
pruned_json = json_encoder.encode(
|
||||
prune_event_dict(
|
||||
original_event.room_version, original_event.get_dict()
|
||||
)
|
||||
|
@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
return
|
||||
|
||||
# Prune the event's dict then convert it to JSON.
|
||||
pruned_json = frozendict_json_encoder.encode(
|
||||
pruned_json = json_encoder.encode(
|
||||
prune_event_dict(event.room_version, event.get_dict())
|
||||
)
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
|
|||
from synapse.storage.databases.main.search import SearchEntry
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -769,9 +769,7 @@ class PersistEventsStore:
|
|||
logger.exception("")
|
||||
raise
|
||||
|
||||
metadata_json = frozendict_json_encoder.encode(
|
||||
event.internal_metadata.get_dict()
|
||||
)
|
||||
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
|
||||
|
||||
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
|
||||
txn.execute(sql, (metadata_json, event.event_id))
|
||||
|
@ -826,10 +824,10 @@ class PersistEventsStore:
|
|||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": frozendict_json_encoder.encode(
|
||||
"internal_metadata": json_encoder.encode(
|
||||
event.internal_metadata.get_dict()
|
||||
),
|
||||
"json": frozendict_json_encoder.encode(event_dict(event)),
|
||||
"json": json_encoder.encode(event_dict(event)),
|
||||
"format_version": event.format_version,
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
|
|
|
@ -31,6 +31,7 @@ from synapse.api.room_versions import (
|
|||
RoomVersions,
|
||||
)
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
|
@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
|||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.types import Collection, JsonDict, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return event_map
|
||||
|
||||
async def get_stripped_room_state_from_event_context(
|
||||
self,
|
||||
context: EventContext,
|
||||
state_types_to_include: List[EventTypes],
|
||||
membership_user_id: Optional[str] = None,
|
||||
) -> List[JsonDict]:
|
||||
"""
|
||||
Retrieve the stripped state from a room, given an event context to retrieve state
|
||||
from as well as the state types to include. Optionally, include the membership
|
||||
events from a specific user.
|
||||
|
||||
"Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
|
||||
are included from each state event.
|
||||
|
||||
Args:
|
||||
context: The event context to retrieve state of the room from.
|
||||
state_types_to_include: The types of state events to include.
|
||||
membership_user_id: An optional user ID to include the stripped membership state
|
||||
events of. This is useful when generating the stripped state of a room for
|
||||
invites. We want to send membership events of the inviter, so that the
|
||||
invitee can display the inviter's profile information if the room lacks any.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, each representing a stripped state event from the room.
|
||||
"""
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
|
||||
# We know this event is not an outlier, so this must be
|
||||
# non-None.
|
||||
assert current_state_ids is not None
|
||||
|
||||
# The state to include
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
for k, e_id in current_state_ids.items()
|
||||
if k[0] in state_types_to_include
|
||||
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
|
||||
]
|
||||
|
||||
state_to_include = await self.get_events(state_to_include_ids)
|
||||
|
||||
return [
|
||||
{
|
||||
"type": e.type,
|
||||
"state_key": e.state_key,
|
||||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for e in state_to_include.values()
|
||||
]
|
||||
|
||||
def _do_fetch(self, conn):
|
||||
"""Takes a database connection and waits for requests for events from
|
||||
the _event_fetch_list queue.
|
||||
|
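As a rough illustration of the filtering that get_stripped_room_state_from_event_context performs, here is the same idea over plain event dictionaries; m.room.member is the standard membership event type, and the field names follow the docstring above:

    def strip_state(events, state_types_to_include, membership_user_id=None):
        return [
            {
                "type": e["type"],
                "state_key": e["state_key"],
                "content": e["content"],
                "sender": e["sender"],
            }
            for e in events
            if e["type"] in state_types_to_include
            or (
                membership_user_id
                and e["type"] == "m.room.member"
                and e["state_key"] == membership_user_id
            )
        ]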
@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
def get_all_new_forward_event_rows(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" LEFT JOIN room_memberships USING (event_id)"
|
||||
" LEFT JOIN rejections USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" AND instance_name = ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
|
@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
def get_ex_outlier_stream_rows_txn(txn):
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" LEFT JOIN room_memberships USING (event_id)"
|
||||
" LEFT JOIN rejections USING (event_id)"
|
||||
" WHERE ? < event_stream_ordering"
|
||||
" AND event_stream_ordering <= ?"
|
||||
" AND out.instance_name = ?"
|
||||
|
|
|
@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="get_remote_media_thumbnails",
|
||||
)
|
||||
|
||||
async def get_remote_media_thumbnail(
|
||||
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch the thumbnail info of given width, height and type.
|
||||
"""
|
||||
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="remote_media_cache_thumbnails",
|
||||
keyvalues={
|
||||
"media_origin": origin,
|
||||
"media_id": media_id,
|
||||
"thumbnail_width": t_width,
|
||||
"thumbnail_height": t_height,
|
||||
"thumbnail_type": t_type,
|
||||
},
|
||||
retcols=(
|
||||
"thumbnail_width",
|
||||
"thumbnail_height",
|
||||
"thumbnail_method",
|
||||
"thumbnail_type",
|
||||
"thumbnail_length",
|
||||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_remote_media_thumbnail",
|
||||
)
|
||||
|
||||
async def store_remote_media_thumbnail(
|
||||
self,
|
||||
origin,
|
||||
|
|
|
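The new get_remote_media_thumbnail lookup above is a plain simple_select_one call: the keyvalues become the WHERE clause, the retcols become the SELECT list, and allow_none=True turns a missing row into None. A hypothetical sketch of the kind of SQL such a helper builds (not the real implementation, which also runs the query on the database thread pool):

    def build_select_one(table, keyvalues, retcols):
        # keyvalues: column -> expected value; retcols: columns to return.
        where = " AND ".join("%s = ?" % k for k in keyvalues)
        sql = "SELECT %s FROM %s WHERE %s" % (", ".join(retcols), table, where)
        return sql, tuple(keyvalues.values())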
@ -18,6 +18,8 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
|
@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class TokenLookupResult:
|
||||
"""Result of looking up an access token.
|
||||
|
||||
Attributes:
|
||||
user_id: The user that this token authenticates as
|
||||
is_guest: Whether the token is for a guest user.
|
||||
shadow_banned: Whether the user is shadow-banned.
|
||||
token_id: The ID of the access token looked up
|
||||
device_id: The device associated with the token, if any.
|
||||
valid_until_ms: The timestamp the token expires, if any.
|
||||
token_owner: The "owner" of the token. This is either the same as the
|
||||
user, or a server admin who is logged in as the user.
|
||||
"""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
is_guest = attr.ib(type=bool, default=False)
|
||||
shadow_banned = attr.ib(type=bool, default=False)
|
||||
token_id = attr.ib(type=Optional[int], default=None)
|
||||
device_id = attr.ib(type=Optional[str], default=None)
|
||||
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||
token_owner = attr.ib(type=str)
|
||||
|
||||
# Make the token owner default to the user ID, which is the common case.
|
||||
@token_owner.default
|
||||
def _default_token_owner(self):
|
||||
return self.user_id
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
return is_trial
|
||||
|
||||
@cached()
|
||||
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
||||
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
token: The access token of a user.
|
||||
Returns:
|
||||
None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
||||
`valid_until_ms`.
|
||||
None, if the token did not match, otherwise a `TokenLookupResult`
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_user_by_access_token", self._query_for_auth, token
|
||||
|
@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||
sql = """
|
||||
SELECT users.name,
|
||||
SELECT users.name as user_id,
|
||||
users.is_guest,
|
||||
users.shadow_banned,
|
||||
access_tokens.id as token_id,
|
||||
access_tokens.device_id,
|
||||
access_tokens.valid_until_ms
|
||||
access_tokens.valid_until_ms,
|
||||
access_tokens.user_id as token_owner
|
||||
FROM users
|
||||
INNER JOIN access_tokens on users.name = access_tokens.user_id
|
||||
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||
WHERE token = ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]
|
||||
return TokenLookupResult(**rows[0])
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
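TokenLookupResult above relies on an attrs idiom worth noting: token_owner's default is computed from another field on the same instance, so it falls back to user_id unless a puppeting admin's ID is supplied explicitly. A self-contained sketch of just that idiom, with a cut-down stand-in class:

    import attr

    @attr.s(frozen=True, slots=True)
    class Token:
        user_id = attr.ib(type=str)
        token_owner = attr.ib(type=str)

        # The default is evaluated against the partially-built instance, so it
        # can read user_id.
        @token_owner.default
        def _default_token_owner(self):
            return self.user_id

    assert Token(user_id="@alice:example.com").token_owner == "@alice:example.com"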
@ -0,0 +1,17 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Whether the access token is an admin token for controlling another user.
|
||||
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
|
|
@ -29,6 +29,7 @@ from typing import (
|
|||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.appservice.api import ApplicationService
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
# define a version of typing.Collection that works on python 3.5
|
||||
|
@ -74,6 +76,7 @@ class Requester(
|
|||
"shadow_banned",
|
||||
"device_id",
|
||||
"app_service",
|
||||
"authenticated_entity",
|
||||
],
|
||||
)
|
||||
):
|
||||
|
@ -104,6 +107,7 @@ class Requester(
|
|||
"shadow_banned": self.shadow_banned,
|
||||
"device_id": self.device_id,
|
||||
"app_server_id": self.app_service.id if self.app_service else None,
|
||||
"authenticated_entity": self.authenticated_entity,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
@ -129,16 +133,18 @@ class Requester(
|
|||
shadow_banned=input["shadow_banned"],
|
||||
device_id=input["device_id"],
|
||||
app_service=appservice,
|
||||
authenticated_entity=input["authenticated_entity"],
|
||||
)
|
||||
|
||||
|
||||
def create_requester(
|
||||
user_id,
|
||||
access_token_id=None,
|
||||
is_guest=False,
|
||||
shadow_banned=False,
|
||||
device_id=None,
|
||||
app_service=None,
|
||||
user_id: Union[str, "UserID"],
|
||||
access_token_id: Optional[int] = None,
|
||||
is_guest: Optional[bool] = False,
|
||||
shadow_banned: Optional[bool] = False,
|
||||
device_id: Optional[str] = None,
|
||||
app_service: Optional["ApplicationService"] = None,
|
||||
authenticated_entity: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create a new ``Requester`` object
|
||||
|
@ -151,14 +157,27 @@ def create_requester(
|
|||
shadow_banned (bool): True if the user making this request is shadow-banned.
|
||||
device_id (str|None): device_id which was set at authentication time
|
||||
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
||||
authenticated_entity: The entity that authenticated when making the request.
|
||||
This is different to the user_id when an admin user or the server is
|
||||
"puppeting" the user.
|
||||
|
||||
Returns:
|
||||
Requester
|
||||
"""
|
||||
if not isinstance(user_id, UserID):
|
||||
user_id = UserID.from_string(user_id)
|
||||
|
||||
if authenticated_entity is None:
|
||||
authenticated_entity = user_id.to_string()
|
||||
|
||||
return Requester(
|
||||
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
|
||||
user_id,
|
||||
access_token_id,
|
||||
is_guest,
|
||||
shadow_banned,
|
||||
device_id,
|
||||
app_service,
|
||||
authenticated_entity,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
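create_requester now accepts either a string or a UserID and defaults authenticated_entity to the requesting user. The defaulting logic reduces to the sketch below, with a trivial stand-in for synapse.types.UserID:

    class UserID:
        # Stand-in for synapse.types.UserID in this sketch.
        def __init__(self, s):
            self._s = s

        @classmethod
        def from_string(cls, s):
            return cls(s)

        def to_string(self):
            return self._s

    def create_requester(user_id, authenticated_entity=None):
        if not isinstance(user_id, UserID):
            user_id = UserID.from_string(user_id)
        if authenticated_entity is None:
            # By default the requester authenticated as themselves; an admin
            # "puppeting" the user passes its own ID here instead.
            authenticated_entity = user_id.to_string()
        return user_id, authenticated_entity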
@ -18,6 +18,7 @@ import logging
|
|||
import re
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer, task
|
||||
|
||||
|
@ -31,9 +32,26 @@ def _reject_invalid_json(val):
|
|||
raise ValueError("Invalid JSON value: '%s'" % val)
|
||||
|
||||
|
||||
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
|
||||
# ensure that valid JSON is produced.
|
||||
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
|
||||
def _handle_frozendict(obj):
|
||||
"""Helper for json_encoder. Makes frozendicts serializable by returning
|
||||
the underlying dict
|
||||
"""
|
||||
if type(obj) is frozendict:
|
||||
# fishing the protected dict out of the object is a bit nasty,
|
||||
# but we don't really want the overhead of copying the dict.
|
||||
return obj._dict
|
||||
raise TypeError(
|
||||
"Object of type %s is not JSON serializable" % obj.__class__.__name__
|
||||
)
|
||||
|
||||
|
||||
# A custom JSON encoder which:
|
||||
# * handles frozendicts
|
||||
# * produces valid JSON (no NaNs etc)
|
||||
# * reduces redundant whitespace
|
||||
json_encoder = json.JSONEncoder(
|
||||
allow_nan=False, separators=(",", ":"), default=_handle_frozendict
|
||||
)
|
||||
|
||||
# Create a custom decoder to reject Python extensions to JSON.
|
||||
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
||||
|
|
|
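The consolidated encoder above is an ordinary json.JSONEncoder whose default hook unwraps frozendicts; the compact separators and allow_nan=False are unchanged. A runnable sketch using a stand-in wrapper class instead of the real frozendict:

    import json

    class Frozen:
        # Stand-in for frozendict in this sketch: a non-dict wrapper the
        # encoder cannot serialize on its own, forcing the default hook.
        def __init__(self, d):
            self._dict = dict(d)

    def _handle_frozen(obj):
        if isinstance(obj, Frozen):
            return obj._dict
        raise TypeError(
            "Object of type %s is not JSON serializable" % obj.__class__.__name__
        )

    json_encoder = json.JSONEncoder(
        allow_nan=False, separators=(",", ":"), default=_handle_frozen
    )

    assert json_encoder.encode({"a": Frozen({"b": 1})}) == '{"a":{"b":1}}'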
@ -13,10 +13,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -24,6 +37,7 @@ from twisted.internet import defer
|
|||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.caches.deferred_cache import DeferredCache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
|
|||
|
||||
|
||||
class _CacheDescriptorBase:
|
||||
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
|
||||
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
||||
self.orig = orig
|
||||
|
||||
arg_spec = inspect.getfullargspec(orig)
|
||||
|
@ -97,8 +111,107 @@ class _CacheDescriptorBase:
|
|||
|
||||
self.add_cache_context = cache_context
|
||||
|
||||
self.cache_key_builder = get_cache_key_builder(
|
||||
self.arg_names, self.arg_defaults
|
||||
)
|
||||
|
||||
class CacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
class _LruCachedFunction(Generic[F]):
|
||||
cache = None # type: LruCache[CacheKey, Any]
|
||||
__call__ = None # type: F
|
||||
|
||||
|
||||
def lru_cache(
|
||||
max_entries: int = 1000, cache_context: bool = False,
|
||||
) -> Callable[[F], _LruCachedFunction[F]]:
|
||||
"""A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
This is more-or-less a drop-in equivalent to functools.lru_cache, although note
|
||||
that the signature is slightly different.
|
||||
|
||||
The main differences with functools.lru_cache are:
|
||||
(a) the size of the cache can be controlled via the cache_factor mechanism
|
||||
(b) the wrapped function can request a "cache_context" which provides a
|
||||
callback mechanism to indicate that the result is no longer valid
|
||||
(c) prometheus metrics are exposed automatically.
|
||||
|
||||
The function should take zero or more arguments, which are used as the key for the
|
||||
cache. Single-argument functions use that argument as the cache key; otherwise the
|
||||
arguments are built into a tuple.
|
||||
|
||||
Cached functions can be "chained" (i.e. a cached function can call other cached
|
||||
functions and get appropriately invalidated when the caches they call are
|
||||
invalidated) by adding a special "cache_context" argument to the function
|
||||
and passing that as a kwarg to all caches called. For example:
|
||||
|
||||
@lru_cache(cache_context=True)
|
||||
def foo(self, key, cache_context):
|
||||
r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||
r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||
return r1 + r2
|
||||
|
||||
The wrapped function also has a 'cache' property which offers direct access to the
|
||||
underlying LruCache.
|
||||
"""
|
||||
|
||||
def func(orig: F) -> _LruCachedFunction[F]:
|
||||
desc = LruCacheDescriptor(
|
||||
orig, max_entries=max_entries, cache_context=cache_context,
|
||||
)
|
||||
return cast(_LruCachedFunction[F], desc)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class LruCacheDescriptor(_CacheDescriptorBase):
|
||||
"""Helper for @lru_cache"""
|
||||
|
||||
class _Sentinel(enum.Enum):
|
||||
sentinel = object()
|
||||
|
||||
def __init__(
|
||||
self, orig, max_entries: int = 1000, cache_context: bool = False,
|
||||
):
|
||||
super().__init__(orig, num_args=None, cache_context=cache_context)
|
||||
self.max_entries = max_entries
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
cache = LruCache(
|
||||
cache_name=self.orig.__name__, max_size=self.max_entries,
|
||||
) # type: LruCache[CacheKey, Any]
|
||||
|
||||
get_cache_key = self.cache_key_builder
|
||||
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def _wrapped(*args, **kwargs):
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
||||
|
||||
cache_key = get_cache_key(args, kwargs)
|
||||
|
||||
ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
|
||||
if ret != sentinel:
|
||||
return ret
|
||||
|
||||
# Add our own `cache_context` to argument list if the wrapped function
|
||||
# has asked for one
|
||||
if self.add_cache_context:
|
||||
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
|
||||
|
||||
ret2 = self.orig(obj, *args, **kwargs)
|
||||
cache.set(cache_key, ret2, callbacks=callbacks)
|
||||
|
||||
return ret2
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
wrapped.cache = cache
|
||||
obj.__dict__[self.orig.__name__] = wrapped
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
This caches deferreds, rather than the results themselves. Deferreds that
|
||||
|
@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
cache_context=False,
|
||||
iterable=False,
|
||||
):
|
||||
|
||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||
|
||||
self.max_entries = max_entries
|
||||
|
@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
iterable=self.iterable,
|
||||
) # type: DeferredCache[CacheKey, Any]
|
||||
|
||||
def get_cache_key_gen(args, kwargs):
|
||||
"""Given some args/kwargs return a generator that resolves into
|
||||
the cache_key.
|
||||
|
||||
We loop through each arg name, looking up if its in the `kwargs`,
|
||||
otherwise using the next argument in `args`. If there are no more
|
||||
args then we try looking the arg name up in the defaults
|
||||
"""
|
||||
pos = 0
|
||||
for nm in self.arg_names:
|
||||
if nm in kwargs:
|
||||
yield kwargs[nm]
|
||||
elif pos < len(args):
|
||||
yield args[pos]
|
||||
pos += 1
|
||||
else:
|
||||
yield self.arg_defaults[nm]
|
||||
|
||||
# By default our cache key is a tuple, but if there is only one item
|
||||
# then don't bother wrapping in a tuple. This is to save memory.
|
||||
if self.num_args == 1:
|
||||
nm = self.arg_names[0]
|
||||
|
||||
def get_cache_key(args, kwargs):
|
||||
if nm in kwargs:
|
||||
return kwargs[nm]
|
||||
elif len(args):
|
||||
return args[0]
|
||||
else:
|
||||
return self.arg_defaults[nm]
|
||||
|
||||
else:
|
||||
|
||||
def get_cache_key(args, kwargs):
|
||||
return tuple(get_cache_key_gen(args, kwargs))
|
||||
get_cache_key = self.cache_key_builder
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def _wrapped(*args, **kwargs):
|
||||
|
@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
||||
else:
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.invalidate_all = cache.invalidate_all
|
||||
wrapped.invalidate_many = cache.invalidate_many
|
||||
wrapped.prefill = cache.prefill
|
||||
|
||||
|
@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
return wrapped
|
||||
|
||||
|
||||
class CacheListDescriptor(_CacheDescriptorBase):
|
||||
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||
"""Wraps an existing cache to support bulk fetching of keys.
|
||||
|
||||
Given a list of keys it looks in the cache to find any hits, then passes
|
||||
|
@ -382,11 +459,13 @@ class _CacheContext:
|
|||
on a lower level.
|
||||
"""
|
||||
|
||||
Cache = Union[DeferredCache, LruCache]
|
||||
|
||||
_cache_context_objects = (
|
||||
WeakValueDictionary()
|
||||
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
|
||||
) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
|
||||
|
||||
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
|
||||
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
|
||||
self._cache = cache
|
||||
self._cache_key = cache_key
|
||||
|
||||
|
@ -396,8 +475,8 @@ class _CacheContext:
|
|||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls, cache, cache_key
|
||||
): # type: (DeferredCache, CacheKey) -> _CacheContext
|
||||
cls, cache: "_CacheContext.Cache", cache_key: CacheKey
|
||||
) -> "_CacheContext":
|
||||
"""Returns an instance constructed with the given arguments.
|
||||
|
||||
A new instance is only created if none already exists.
|
||||
|
@ -418,7 +497,7 @@ def cached(
|
|||
cache_context: bool = False,
|
||||
iterable: bool = False,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
func = lambda orig: CacheDescriptor(
|
||||
func = lambda orig: DeferredCacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
|
@ -460,7 +539,7 @@ def cachedList(
|
|||
def batch_do_something(self, first_arg, second_args):
|
||||
...
|
||||
"""
|
||||
func = lambda orig: CacheListDescriptor(
|
||||
func = lambda orig: DeferredCacheListDescriptor(
|
||||
orig,
|
||||
cached_method_name=cached_method_name,
|
||||
list_name=list_name,
|
||||
|
@ -468,3 +547,65 @@ def cachedList(
|
|||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
|
||||
|
||||
def get_cache_key_builder(
|
||||
param_names: Sequence[str], param_defaults: Mapping[str, Any]
|
||||
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
|
||||
"""Construct a function which will build cache keys suitable for a cached function
|
||||
|
||||
Args:
|
||||
param_names: list of formal parameter names for the cached function
|
||||
param_defaults: a mapping from parameter name to default value for that param
|
||||
|
||||
Returns:
|
||||
A function which will take an (args, kwargs) pair and return a cache key
|
||||
"""
|
||||
|
||||
# By default our cache key is a tuple, but if there is only one item
|
||||
# then don't bother wrapping in a tuple. This is to save memory.
|
||||
|
||||
if len(param_names) == 1:
|
||||
nm = param_names[0]
|
||||
|
||||
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||
if nm in kwargs:
|
||||
return kwargs[nm]
|
||||
elif len(args):
|
||||
return args[0]
|
||||
else:
|
||||
return param_defaults[nm]
|
||||
|
||||
else:
|
||||
|
||||
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
|
||||
|
||||
return get_cache_key
|
||||
|
||||
|
||||
def _get_cache_key_gen(
|
||||
param_names: Iterable[str],
|
||||
param_defaults: Mapping[str, Any],
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> Iterable[Any]:
|
||||
"""Given some args/kwargs return a generator that resolves into
|
||||
the cache_key.
|
||||
|
||||
This is essentially the same operation as `inspect.getcallargs`, but optimised so
|
||||
that we don't need to inspect the target function for each call.
|
||||
"""
|
||||
|
||||
# We loop through each arg name, looking up if its in the `kwargs`,
|
||||
# otherwise using the next argument in `args`. If there are no more
|
||||
# args then we try looking the arg name up in the defaults.
|
||||
pos = 0
|
||||
for nm in param_names:
|
||||
if nm in kwargs:
|
||||
yield kwargs[nm]
|
||||
elif pos < len(args):
|
||||
yield args[pos]
|
||||
pos += 1
|
||||
else:
|
||||
yield param_defaults[nm]
|
||||
|
|
|
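get_cache_key_builder and _get_cache_key_gen above simply hoist the old per-descriptor key construction into module-level helpers. The behaviour, condensed into a standalone function:

    def build_key(param_names, param_defaults, args, kwargs):
        # Resolve each parameter from kwargs, then positionally, then from the
        # defaults; a single parameter is used bare to avoid a 1-tuple.
        vals = []
        pos = 0
        for name in param_names:
            if name in kwargs:
                vals.append(kwargs[name])
            elif pos < len(args):
                vals.append(args[pos])
                pos += 1
            else:
                vals.append(param_defaults[name])
        return vals[0] if len(param_names) == 1 else tuple(vals)

    assert build_key(["room_id"], {}, ("!room:hs",), {}) == "!room:hs"
    assert build_key(["a", "b"], {"b": 2}, (1,), {}) == (1, 2)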
@ -13,8 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
|
||||
|
@ -49,23 +47,3 @@ def unfreeze(o):
|
|||
pass
|
||||
|
||||
return o
|
||||
|
||||
|
||||
def _handle_frozendict(obj):
|
||||
"""Helper for EventEncoder. Makes frozendicts serializable by returning
|
||||
the underlying dict
|
||||
"""
|
||||
if type(obj) is frozendict:
|
||||
# fishing the protected dict out of the object is a bit nasty,
|
||||
# but we don't really want the overhead of copying the dict.
|
||||
return obj._dict
|
||||
raise TypeError(
|
||||
"Object of type %s is not JSON serializable" % obj.__class__.__name__
|
||||
)
|
||||
|
||||
|
||||
# A JSONEncoder which is capable of encoding frozendicts without barfing.
|
||||
# Additionally reduce the whitespace produced by JSON encoding.
|
||||
frozendict_json_encoder = json.JSONEncoder(
|
||||
allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
|
||||
)
|
||||
|
|
|
@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
|
|||
failure_ts,
|
||||
retry_interval,
|
||||
backoff_on_failure=backoff_on_failure,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -21,45 +21,6 @@ except ImportError:
|
|||
from twisted.internet.pollreactor import PollReactor as Reactor
|
||||
from twisted.internet.main import installReactor
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.utils import default_config, setup_test_homeserver
|
||||
|
||||
|
||||
async def make_homeserver(reactor, config=None):
|
||||
"""
|
||||
Make a Homeserver suitable for running benchmarks against.
|
||||
|
||||
Args:
|
||||
reactor: A Twisted reactor to run under.
|
||||
config: A HomeServerConfig to use, or None.
|
||||
"""
|
||||
cleanup_tasks = []
|
||||
clock = Clock(reactor)
|
||||
|
||||
if not config:
|
||||
config = default_config("test")
|
||||
|
||||
config_obj = HomeServerConfig()
|
||||
config_obj.parse_config_dict(config, "", "")
|
||||
|
||||
hs = setup_test_homeserver(
|
||||
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
|
||||
)
|
||||
stor = hs.get_datastore()
|
||||
|
||||
# Run the database background updates.
|
||||
if hasattr(stor.db_pool.updates, "do_next_background_update"):
|
||||
while not await stor.db_pool.updates.has_completed_background_updates():
|
||||
await stor.db_pool.updates.do_next_background_update(1)
|
||||
|
||||
def cleanup():
|
||||
for i in cleanup_tasks:
|
||||
i()
|
||||
|
||||
return hs, clock.sleep, cleanup
|
||||
|
||||
|
||||
def make_reactor():
|
||||
"""
|
||||
|
|
|
@ -12,20 +12,20 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from argparse import REMAINDER
|
||||
from contextlib import redirect_stderr
|
||||
from io import StringIO
|
||||
|
||||
import pyperf
|
||||
from synmark import make_reactor
|
||||
from synmark.suites import SUITES
|
||||
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
from twisted.logger import globalLogBeginner, textFileLogObserver
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synmark import make_reactor
|
||||
from synmark.suites import SUITES
|
||||
|
||||
from tests.utils import setupdb
|
||||
|
||||
|
||||
|
|
|
@ -13,20 +13,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from io import StringIO
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from pyperf import perf_counter
|
||||
from synmark import make_homeserver
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.logger import LogBeginner, Logger, LogPublisher
|
||||
from twisted.logger import LogBeginner, LogPublisher
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
|
||||
from synapse.logging._structured import setup_structured_logging
|
||||
from synapse.config.logger import _setup_stdlib_logging
|
||||
from synapse.logging import RemoteHandler
|
||||
from synapse.util import Clock
|
||||
|
||||
|
||||
class LineCounter(LineOnlyReceiver):
|
||||
|
@ -62,7 +64,15 @@ async def main(reactor, loops):
|
|||
logger_factory.on_done = Deferred()
|
||||
port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1")
|
||||
|
||||
hs, wait, cleanup = await make_homeserver(reactor)
|
||||
# A fake homeserver config.
|
||||
class Config:
|
||||
server_name = "synmark-" + str(loops)
|
||||
no_redirect_stdio = True
|
||||
|
||||
hs_config = Config()
|
||||
|
||||
# To be able to sleep.
|
||||
clock = Clock(reactor)
|
||||
|
||||
errors = StringIO()
|
||||
publisher = LogPublisher()
|
||||
|
@ -72,47 +82,49 @@ async def main(reactor, loops):
|
|||
)
|
||||
|
||||
log_config = {
|
||||
"loggers": {"synapse": {"level": "DEBUG"}},
|
||||
"drains": {
|
||||
"version": 1,
|
||||
"loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}},
|
||||
"formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}},
|
||||
"handlers": {
|
||||
"tersejson": {
|
||||
"type": "network_json_terse",
|
||||
"class": "synapse.logging.RemoteHandler",
|
||||
"host": "127.0.0.1",
|
||||
"port": port.getHost().port,
|
||||
"maximum_buffer": 100,
|
||||
"_reactor": reactor,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher)
|
||||
logging_system = setup_structured_logging(
|
||||
hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False
|
||||
logger = logging.getLogger("synapse.logging.test_terse_json")
|
||||
_setup_stdlib_logging(
|
||||
hs_config, log_config, logBeginner=beginner,
|
||||
)
|
||||
|
||||
# Wait for it to connect...
|
||||
await logging_system._observers[0]._service.whenConnected()
|
||||
for handler in logging.getLogger("synapse").handlers:
|
||||
if isinstance(handler, RemoteHandler):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Improperly configured: no RemoteHandler found.")
|
||||
|
||||
await handler._service.whenConnected()
|
||||
|
||||
start = perf_counter()
|
||||
|
||||
# Send a bunch of useful messages
|
||||
for i in range(0, loops):
|
||||
logger.info("test message %s" % (i,))
|
||||
logger.info("test message %s", i)
|
||||
|
||||
if (
|
||||
len(logging_system._observers[0]._buffer)
|
||||
== logging_system._observers[0].maximum_buffer
|
||||
):
|
||||
while (
|
||||
len(logging_system._observers[0]._buffer)
|
||||
> logging_system._observers[0].maximum_buffer / 2
|
||||
):
|
||||
await wait(0.01)
|
||||
if len(handler._buffer) == handler.maximum_buffer:
|
||||
while len(handler._buffer) > handler.maximum_buffer / 2:
|
||||
await clock.sleep(0.01)
|
||||
|
||||
await logger_factory.on_done
|
||||
|
||||
end = perf_counter() - start
|
||||
|
||||
logging_system.stop()
|
||||
handler.close()
|
||||
port.stopListening()
|
||||
cleanup()
|
||||
|
||||
return end
|
||||
|
|
|
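The benchmark now drives a stock logging.Logger at a RemoteHandler and applies backpressure by pausing the producer whenever the handler's buffer is full, resuming once it has drained to half capacity. A generic sketch of that loop, where asyncio.sleep stands in for the Clock.sleep used above and the handler is assumed to expose the same _buffer and maximum_buffer attributes:

    import asyncio

    async def log_with_backpressure(logger, handler, count):
        for i in range(count):
            logger.info("test message %s", i)
            if len(handler._buffer) == handler.maximum_buffer:
                # Let the handler flush before producing more.
                while len(handler._buffer) > handler.maximum_buffer / 2:
                    await asyncio.sleep(0.01)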
@ -29,6 +29,7 @@ from synapse.api.errors import (
|
|||
MissingClientTokenError,
|
||||
ResourceLimitError,
|
||||
)
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
|
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_user_valid_token(self):
|
||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
||||
user_info = TokenLookupResult(
|
||||
user_id=self.test_user, token_id=5, device_id="device"
|
||||
)
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(user_info)
|
||||
)
|
||||
|
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||
|
||||
def test_get_user_by_req_user_missing_token(self):
|
||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
||||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(user_info)
|
||||
)
|
||||
|
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
def test_get_user_from_macaroon(self):
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(
|
||||
{"name": "@baldrick:matrix.org", "device_id": "device"}
|
||||
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
|
|||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
)
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertEqual(user_id, user_info.user_id)
|
||||
|
||||
# TODO: device_id should come from the macaroon, but currently comes
|
||||
# from the db.
|
||||
self.assertEqual(user_info["device_id"], "device")
|
||||
self.assertEqual(user_info.device_id, "device")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_guest_user_from_macaroon(self):
|
||||
|
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(serialized)
|
||||
)
|
||||
user = user_info["user"]
|
||||
is_guest = user_info["is_guest"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertTrue(is_guest)
|
||||
self.assertEqual(user_id, user_info.user_id)
|
||||
self.assertTrue(user_info.is_guest)
|
||||
self.store.get_user_by_id.assert_called_with(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
if token != tok:
|
||||
return defer.succeed(None)
|
||||
return defer.succeed(
|
||||
{
|
||||
"name": USER_ID,
|
||||
"is_guest": False,
|
||||
"token_id": 1234,
|
||||
"device_id": "DEVICE",
|
||||
}
|
||||
TokenLookupResult(
|
||||
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
|
||||
)
|
||||
)
|
||||
|
||||
self.store.get_user_by_access_token = get_user
|
||||
|
|
|
@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
|
|||
|
||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=True,
|
||||
None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
|
@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
|
|||
|
||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=False,
|
||||
None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.service = ApplicationService(
|
||||
id="unique_identifier",
|
||||
sender="@as:test",
|
||||
url="some_url",
|
||||
token="some_token",
|
||||
hostname="matrix.org", # only used by get_groups_for_user
|
||||
|
|
|
@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
# make sure that our device ID has changed
|
||||
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
||||
|
||||
self.assertEqual(user_info["device_id"], retrieved_device_id)
|
||||
self.assertEqual(user_info.device_id, retrieved_device_id)
|
||||
|
||||
# make sure the device has the display name that was set from the login
|
||||
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
||||
|
|
|
@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
self.info = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
||||
)
|
||||
self.token_id = self.info["token_id"]
|
||||
self.token_id = self.info.token_id
|
||||
|
||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
|
||||
class LoggerCleanupMixin:
|
||||
def get_logger(self, handler):
|
||||
"""
|
||||
Attach a handler to a logger and add clean-ups to revert this afterwards.
|
||||
"""
|
||||
# Create a logger and add the handler to it.
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Ensure the logger actually logs something.
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Ensure the logger gets cleaned-up appropriately.
|
||||
self.addCleanup(logger.removeHandler, handler)
|
||||
self.addCleanup(logger.setLevel, logging.NOTSET)
|
||||
|
||||
return logger
|
169
tests/logging/test_remote_handler.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
|
||||
from synapse.logging import RemoteHandler
|
||||
|
||||
from tests.logging import LoggerCleanupMixin
|
||||
from tests.server import FakeTransport, get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
def connect_logging_client(reactor, client_id):
|
||||
# This is essentially tests.server.connect_client, but disabling autoflush on
|
||||
# the client transport. This is necessary to avoid an infinite loop due to
|
||||
# sending of data via the logging transport causing additional logs to be
|
||||
# written.
|
||||
factory = reactor.tcpClients.pop(client_id)[2]
|
||||
client = factory.buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, reactor))
|
||||
client.makeConnection(FakeTransport(server, reactor, autoflush=False))
|
||||
|
||||
return client, server
|
||||
|
||||
|
||||
class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
|
||||
def setUp(self):
|
||||
self.reactor, _ = get_clock()
|
||||
|
||||
def test_log_output(self):
|
||||
"""
|
||||
The remote handler delivers logs over TCP.
|
||||
"""
|
||||
handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
# Trigger the connection
|
||||
client, server = connect_logging_client(self.reactor, 0)
|
||||
|
||||
# Trigger data being sent
|
||||
client.transport.flush()
|
||||
|
||||
# One log message, with a single trailing newline
|
||||
logs = server.data.decode("utf8").splitlines()
|
||||
self.assertEqual(len(logs), 1)
|
||||
self.assertEqual(server.data.count(b"\n"), 1)
|
||||
|
||||
# Ensure the data passed through properly.
|
||||
self.assertEqual(logs[0], "Hello there, wally!")
|
||||
|
||||
def test_log_backpressure_debug(self):
|
||||
"""
|
||||
When backpressure is hit, DEBUG logs will be shed.
|
||||
"""
|
||||
handler = RemoteHandler(
|
||||
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||
)
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
# Send some debug messages
|
||||
for i in range(0, 3):
|
||||
logger.debug("debug %s" % (i,))
|
||||
|
||||
# Send a bunch of useful messages
|
||||
for i in range(0, 7):
|
||||
logger.info("info %s" % (i,))
|
||||
|
||||
# The last debug message pushes it past the maximum buffer
|
||||
logger.debug("too much debug")
|
||||
|
||||
# Allow the reconnection
|
||||
client, server = connect_logging_client(self.reactor, 0)
|
||||
client.transport.flush()
|
||||
|
||||
# Only the 7 infos made it through, the debugs were elided
|
||||
logs = server.data.splitlines()
|
||||
self.assertEqual(len(logs), 7)
|
||||
self.assertNotIn(b"debug", server.data)
|
||||
|
||||
def test_log_backpressure_info(self):
|
||||
"""
|
||||
When backpressure is hit, DEBUG and INFO logs will be shed.
|
||||
"""
|
||||
handler = RemoteHandler(
|
||||
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||
)
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
# Send some debug messages
|
||||
for i in range(0, 3):
|
||||
logger.debug("debug %s" % (i,))
|
||||
|
||||
# Send a bunch of useful messages
|
||||
for i in range(0, 10):
|
||||
logger.warning("warn %s" % (i,))
|
||||
|
||||
# Send a bunch of info messages
|
||||
for i in range(0, 3):
|
||||
logger.info("info %s" % (i,))
|
||||
|
||||
# The last debug message pushes it past the maximum buffer
|
||||
logger.debug("too much debug")
|
||||
|
||||
# Allow the reconnection
|
||||
client, server = connect_logging_client(self.reactor, 0)
|
||||
client.transport.flush()
|
||||
|
||||
# The 10 warnings made it through, the debugs and infos were elided
|
||||
logs = server.data.splitlines()
|
||||
self.assertEqual(len(logs), 10)
|
||||
self.assertNotIn(b"debug", server.data)
|
||||
self.assertNotIn(b"info", server.data)
|
||||
|
||||
def test_log_backpressure_cut_middle(self):
|
||||
"""
|
||||
When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
|
||||
it will cut the middle messages out.
|
||||
"""
|
||||
handler = RemoteHandler(
|
||||
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||
)
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
# Send a bunch of useful messages
|
||||
for i in range(0, 20):
|
||||
logger.warning("warn %s" % (i,))
|
||||
|
||||
# Allow the reconnection
|
||||
client, server = connect_logging_client(self.reactor, 0)
|
||||
client.transport.flush()
|
||||
|
||||
# The first five and last five warnings made it through, the debugs and
|
||||
# infos were elided
|
||||
logs = server.data.decode("utf8").splitlines()
|
||||
self.assertEqual(
|
||||
["warn %s" % (i,) for i in range(5)]
|
||||
+ ["warn %s" % (i,) for i in range(15, 20)],
|
||||
logs,
|
||||
)
|
||||
|
||||
def test_cancel_connection(self):
|
||||
"""
|
||||
Gracefully handle the connection being cancelled.
|
||||
"""
|
||||
handler = RemoteHandler(
|
||||
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||
)
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
# Send a message.
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
# Do not accept the connection and shutdown. This causes the pending
|
||||
# connection to be cancelled (and should not raise any exceptions).
|
||||
handler.close()
|
|
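Aside: the tests above drive `RemoteHandler` purely through the standard library `logging` API. A minimal sketch of wiring the handler up outside the test harness, assuming the constructor arguments seen above (the host, port and `maximum_buffer` values are illustrative; `_reactor` is only passed in the tests):

```
import logging

from synapse.logging import RemoteHandler

# Forward records from the "synapse" logger to a remote TCP log collector,
# buffering at most 10 records while the connection is down (as in the tests).
handler = RemoteHandler("127.0.0.1", 9000, maximum_buffer=10)

logger = logging.getLogger("synapse")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

logger.info("Hello there, %s!", "wally")
```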
@@ -1,214 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import os.path
import shutil
import sys
import textwrap

from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile

from synapse.config.logger import setup_logging
from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContext

from tests.unittest import DEBUG, HomeserverTestCase


class FakeBeginner:
    def beginLoggingTo(self, observers, **kwargs):
        self.observers = observers


class StructuredLoggingTestBase:
    """
    Test base that registers a cleanup handler to reset the stdlib log handler
    to 'unset'.
    """

    def prepare(self, reactor, clock, hs):
        def _cleanup():
            logging.getLogger("synapse").setLevel(logging.NOTSET)

        self.addCleanup(_cleanup)


class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
    """
    Tests for Synapse's structured logging support.
    """

    def test_output_to_json_round_trip(self):
        """
        Synapse logs can be outputted to JSON and then read back again.
        """
        temp_dir = self.mktemp()
        os.mkdir(temp_dir)
        self.addCleanup(shutil.rmtree, temp_dir)

        json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))

        log_config = {
            "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
        }

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs, self.hs.config, log_config, logBeginner=beginner
        )

        # Make a logger and send an event
        logger = Logger(
            namespace="tests.logging.test_structured", observer=beginner.observers[0]
        )
        logger.info("Hello there, {name}!", name="wally")

        # Read the log file and check it has the event we sent
        with open(json_log_file, "r") as f:
            logged_events = list(eventsFromJSONLogFile(f))
        self.assertEqual(len(logged_events), 1)

        # The event pulled from the file should render fine
        self.assertEqual(
            eventAsText(logged_events[0], includeTimestamp=False),
            "[tests.logging.test_structured#info] Hello there, wally!",
        )

    def test_output_to_text(self):
        """
        Synapse logs can be outputted to text.
        """
        temp_dir = self.mktemp()
        os.mkdir(temp_dir)
        self.addCleanup(shutil.rmtree, temp_dir)

        log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))

        log_config = {"drains": {"file": {"type": "file", "location": log_file}}}

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs, self.hs.config, log_config, logBeginner=beginner
        )

        # Make a logger and send an event
        logger = Logger(
            namespace="tests.logging.test_structured", observer=beginner.observers[0]
        )
        logger.info("Hello there, {name}!", name="wally")

        # Read the log file and check it has the event we sent
        with open(log_file, "r") as f:
            logged_events = f.read().strip().split("\n")
        self.assertEqual(len(logged_events), 1)

        # The event pulled from the file should render fine
        self.assertTrue(
            logged_events[0].endswith(
                " - tests.logging.test_structured - INFO - None - Hello there, wally!"
            )
        )

    def test_collects_logcontext(self):
        """
        Test that log outputs have the attached logging context.
        """
        log_config = {"drains": {}}

        # Begin the logger with our config
        beginner = FakeBeginner()
        publisher = setup_structured_logging(
            self.hs, self.hs.config, log_config, logBeginner=beginner
        )

        logs = []

        publisher.addObserver(logs.append)

        # Make a logger and send an event
        logger = Logger(
            namespace="tests.logging.test_structured", observer=beginner.observers[0]
        )

        with LoggingContext("testcontext", request="somereq"):
            logger.info("Hello there, {name}!", name="steve")

        self.assertEqual(len(logs), 1)
        self.assertEqual(logs[0]["request"], "somereq")


class StructuredLoggingConfigurationFileTestCase(
    StructuredLoggingTestBase, HomeserverTestCase
):
    def make_homeserver(self, reactor, clock):

        tempdir = self.mktemp()
        os.mkdir(tempdir)
        log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
        self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))

        config = self.default_config()
        config["log_config"] = log_config_file

        with open(log_config_file, "w") as f:
            f.write(
                textwrap.dedent(
                    """\
                    structured: true

                    drains:
                        file:
                            type: file_json
                            location: %s
                    """
                    % (self.homeserver_log,)
                )
            )

        self.addCleanup(self._sys_cleanup)

        return self.setup_test_homeserver(config=config)

    def _sys_cleanup(self):
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

    # Do not remove! We need the logging system to be set other than WARNING.
    @DEBUG
    def test_log_output(self):
        """
        When a structured logging config is given, Synapse will use it.
        """
        beginner = FakeBeginner()
        publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)

        # Make a logger and send an event
        logger = Logger(namespace="tests.logging.test_structured", observer=publisher)

        with LoggingContext("testcontext", request="somereq"):
            logger.info("Hello there, {name}!", name="steve")

        with open(self.homeserver_log, "r") as f:
            logged_events = [
                eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
            ]

        logs = "\n".join(logged_events)
        self.assertTrue("***** STARTING SERVER *****" in logs)
        self.assertTrue("Hello there, steve!" in logs)
@@ -14,57 +14,33 @@
# limitations under the License.

import json
from collections import Counter
import logging
from io import StringIO

from twisted.logger import Logger
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter

from synapse.logging._structured import setup_structured_logging

from tests.server import connect_client
from tests.unittest import HomeserverTestCase

from .test_structured import FakeBeginner, StructuredLoggingTestBase
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase


class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
    def test_log_output(self):
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
    def test_terse_json_output(self):
        """
        The Terse JSON outputter delivers simplified structured logs over TCP.
        The Terse JSON formatter converts log messages to JSON.
        """
        log_config = {
            "drains": {
                "tersejson": {
                    "type": "network_json_terse",
                    "host": "127.0.0.1",
                    "port": 8000,
                }
            }
        }
        output = StringIO()

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs, self.hs.config, log_config, logBeginner=beginner
        )
        handler = logging.StreamHandler(output)
        handler.setFormatter(TerseJsonFormatter())
        logger = self.get_logger(handler)

        logger = Logger(
            namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
        )
        logger.info("Hello there, {name}!", name="wally")
        logger.info("Hello there, %s!", "wally")

        # Trigger the connection
        self.pump()

        _, server = connect_client(self.reactor, 0)

        # Trigger data being sent
        self.pump()

        # One log message, with a single trailing newline
        logs = server.data.decode("utf8").splitlines()
        # One log message, with a single trailing newline.
        data = output.getvalue()
        logs = data.splitlines()
        self.assertEqual(len(logs), 1)
        self.assertEqual(server.data.count(b"\n"), 1)
        self.assertEqual(data.count("\n"), 1)
        log = json.loads(logs[0])

        # The terse logger should give us these keys.

@@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
            "log",
            "time",
            "level",
            "log_namespace",
            "request",
            "scope",
            "server_name",
            "name",
            "namespace",
        ]
        self.assertCountEqual(log.keys(), expected_log_keys)
        self.assertEqual(log["log"], "Hello there, wally!")

    def test_extra_data(self):
        """
        Additional information can be included in the structured logging.
        """
        output = StringIO()

        handler = logging.StreamHandler(output)
        handler.setFormatter(TerseJsonFormatter())
        logger = self.get_logger(handler)

        logger.info(
            "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
        )

        # One log message, with a single trailing newline.
        data = output.getvalue()
        logs = data.splitlines()
        self.assertEqual(len(logs), 1)
        self.assertEqual(data.count("\n"), 1)
        log = json.loads(logs[0])

        # The terse logger should give us these keys.
        expected_log_keys = [
            "log",
            "time",
            "level",
            "namespace",
            # The additional keys given via extra.
            "foo",
            "int",
            "bool",
        ]
        self.assertCountEqual(log.keys(), expected_log_keys)

        # It contains the data we expect.
        self.assertEqual(log["name"], "wally")
        # Check the values of the extra fields.
        self.assertEqual(log["foo"], "bar")
        self.assertEqual(log["int"], 3)
        self.assertIs(log["bool"], True)

    def test_log_backpressure_debug(self):
    def test_json_output(self):
        """
        When backpressure is hit, DEBUG logs will be shed.
        The Terse JSON formatter converts log messages to JSON.
        """
        log_config = {
            "loggers": {"synapse": {"level": "DEBUG"}},
            "drains": {
                "tersejson": {
                    "type": "network_json_terse",
                    "host": "127.0.0.1",
                    "port": 8000,
                    "maximum_buffer": 10,
                }
            },
        }
        output = StringIO()

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs,
            self.hs.config,
            log_config,
            logBeginner=beginner,
            redirect_stdlib_logging=False,
        )
        handler = logging.StreamHandler(output)
        handler.setFormatter(JsonFormatter())
        logger = self.get_logger(handler)

        logger = Logger(
            namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
        )
        logger.info("Hello there, %s!", "wally")

        # Send some debug messages
        for i in range(0, 3):
            logger.debug("debug %s" % (i,))
        # One log message, with a single trailing newline.
        data = output.getvalue()
        logs = data.splitlines()
        self.assertEqual(len(logs), 1)
        self.assertEqual(data.count("\n"), 1)
        log = json.loads(logs[0])

        # Send a bunch of useful messages
        for i in range(0, 7):
            logger.info("test message %s" % (i,))

        # The last debug message pushes it past the maximum buffer
        logger.debug("too much debug")

        # Allow the reconnection
        _, server = connect_client(self.reactor, 0)
        self.pump()

        # Only the 7 infos made it through, the debugs were elided
        logs = server.data.splitlines()
        self.assertEqual(len(logs), 7)

    def test_log_backpressure_info(self):
        """
        When backpressure is hit, DEBUG and INFO logs will be shed.
        """
        log_config = {
            "loggers": {"synapse": {"level": "DEBUG"}},
            "drains": {
                "tersejson": {
                    "type": "network_json_terse",
                    "host": "127.0.0.1",
                    "port": 8000,
                    "maximum_buffer": 10,
                }
            },
        }

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs,
            self.hs.config,
            log_config,
            logBeginner=beginner,
            redirect_stdlib_logging=False,
        )

        logger = Logger(
            namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
        )

        # Send some debug messages
        for i in range(0, 3):
            logger.debug("debug %s" % (i,))

        # Send a bunch of useful messages
        for i in range(0, 10):
            logger.warn("test warn %s" % (i,))

        # Send a bunch of info messages
        for i in range(0, 3):
            logger.info("test message %s" % (i,))

        # The last debug message pushes it past the maximum buffer
        logger.debug("too much debug")

        # Allow the reconnection
        client, server = connect_client(self.reactor, 0)
        self.pump()

        # The 10 warnings made it through, the debugs and infos were elided
        logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
        self.assertEqual(len(logs), 10)

        self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})

    def test_log_backpressure_cut_middle(self):
        """
        When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
        it will cut the middle messages out.
        """
        log_config = {
            "loggers": {"synapse": {"level": "DEBUG"}},
            "drains": {
                "tersejson": {
                    "type": "network_json_terse",
                    "host": "127.0.0.1",
                    "port": 8000,
                    "maximum_buffer": 10,
                }
            },
        }

        # Begin the logger with our config
        beginner = FakeBeginner()
        setup_structured_logging(
            self.hs,
            self.hs.config,
            log_config,
            logBeginner=beginner,
            redirect_stdlib_logging=False,
        )

        logger = Logger(
            namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
        )

        # Send a bunch of useful messages
        for i in range(0, 20):
            logger.warn("test warn", num=i)

        # Allow the reconnection
        client, server = connect_client(self.reactor, 0)
        self.pump()

        # The first five and last five warnings made it through, the debugs and
        # infos were elided
        logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
        self.assertEqual(len(logs), 10)
        self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
        self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
        # The terse logger should give us these keys.
        expected_log_keys = [
            "log",
            "level",
            "namespace",
        ]
        self.assertCountEqual(log.keys(), expected_log_keys)
        self.assertEqual(log["log"], "Hello there, wally!")
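Aside: the new tests exercise the formatters through plain stdlib logging rather than the removed drain config. A small sketch of the same wiring outside the tests, assuming the imports used above (the target stream is arbitrary; `JsonFormatter` can be swapped in for `TerseJsonFormatter` the same way):

```
import logging
import sys

from synapse.logging._terse_json import TerseJsonFormatter

# Emit one JSON object per log record, as the tests do with a StringIO.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(TerseJsonFormatter())

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# Extra fields are carried through into the JSON output.
logger.info("Hello there, %s!", "wally", extra={"foo": "bar"})
```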
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(self.access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.pusher = self.get_success(
            self.hs.get_pusherpool().add_pusher(
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(

@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(

@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(

@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(

@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
        user_tuple = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_tuple["token_id"]
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(
@@ -16,7 +16,6 @@ import logging
from typing import Any, Callable, List, Optional, Tuple

import attr
import hiredis

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol

@@ -39,12 +38,22 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport, render

try:
    import hiredis
except ImportError:
    hiredis = None

logger = logging.getLogger(__name__)


class BaseStreamTestCase(unittest.HomeserverTestCase):
    """Base class for tests of the replication streams"""

    # hiredis is an optional dependency so we don't want to require it for running
    # the tests.
    if not hiredis:
        skip = "Requires hiredis"

    servlets = [
        streams.register_servlets,
    ]

@@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
            homeserver_to_use=GenericWorkerServer,
            config=config,
            reactor=self.reactor,
            **kwargs
            **kwargs,
        )

        # If the instance is in the `instance_map` config then workers may try

@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
                sender=sender,
                type="test_event",
                content={"body": body},
                **kwargs
                **kwargs,
            )
        )
277  tests/replication/test_multi_media_repo.py  (new file)

@@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from binascii import unhexlify
from typing import Tuple

from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
from twisted.web.server import Request

from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.server import HomeServer

from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport

logger = logging.getLogger(__name__)

test_server_connection_factory = None


class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
    """Checks running multiple media repos work correctly.
    """

    servlets = [
        admin.register_servlets_for_client_rest_resource,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        self.user_id = self.register_user("user", "pass")
        self.access_token = self.login("user", "pass")

        self.reactor.lookups["example.com"] = "127.0.0.2"

    def default_config(self):
        conf = super().default_config()
        conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
        return conf

    def _get_media_req(
        self, hs: HomeServer, target: str, media_id: str
    ) -> Tuple[FakeChannel, Request]:
        """Request some remote media from the given HS by calling the download
        API.

        This then triggers an outbound request from the HS to the target.

        Returns:
            The channel for the *client* request and the *outbound* request for
            the media which the caller should respond to.
        """

        request, channel = self.make_request(
            "GET",
            "/{}/{}".format(target, media_id),
            shorthand=False,
            access_token=self.access_token,
        )
        request.render(hs.get_media_repository_resource().children[b"download"])
        self.pump()

        clients = self.reactor.tcpClients
        self.assertGreaterEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()

        # build the test server
        server_tls_protocol = _build_test_server(get_connection_factory())

        # now, tell the client protocol factory to build the client protocol (it will be a
        # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
        # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_tls_protocol, self.reactor, client_protocol)
        )

        # tell the server tls protocol to send its stuff back to the client, too
        server_tls_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_tls_protocol)
        )

        # fish the test server back out of the server-side TLS protocol.
        http_server = server_tls_protocol.wrappedProtocol

        # give the reactor a pump to get the TLS juices flowing.
        self.reactor.pump((0.1,))

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]

        self.assertEqual(request.method, b"GET")
        self.assertEqual(
            request.path,
            "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
        )
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
        )

        return channel, request

    def test_basic(self):
        """Test basic fetching of remote media from a single worker.
        """
        hs1 = self.make_worker_hs("synapse.app.generic_worker")

        channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")

        request.setResponseCode(200)
        request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
        request.write(b"Hello!")
        request.finish()

        self.pump(0.1)

        self.assertEqual(channel.code, 200)
        self.assertEqual(channel.result["body"], b"Hello!")

    def test_download_simple_file_race(self):
        """Test that fetching remote media from two different processes at the
        same time works.
        """
        hs1 = self.make_worker_hs("synapse.app.generic_worker")
        hs2 = self.make_worker_hs("synapse.app.generic_worker")

        start_count = self._count_remote_media()

        # Make two requests without responding to the outbound media requests.
        channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
        channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")

        # Respond to the first outbound media request and check that the client
        # request is successful
        request1.setResponseCode(200)
        request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
        request1.write(b"Hello!")
        request1.finish()

        self.pump(0.1)

        self.assertEqual(channel1.code, 200, channel1.result["body"])
        self.assertEqual(channel1.result["body"], b"Hello!")

        # Now respond to the second with the same content.
        request2.setResponseCode(200)
        request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
        request2.write(b"Hello!")
        request2.finish()

        self.pump(0.1)

        self.assertEqual(channel2.code, 200, channel2.result["body"])
        self.assertEqual(channel2.result["body"], b"Hello!")

        # We expect only one new file to have been persisted.
        self.assertEqual(start_count + 1, self._count_remote_media())

    def test_download_image_race(self):
        """Test that fetching remote *images* from two different processes at
        the same time works.

        This checks that races generating thumbnails are handled correctly.
        """
        hs1 = self.make_worker_hs("synapse.app.generic_worker")
        hs2 = self.make_worker_hs("synapse.app.generic_worker")

        start_count = self._count_remote_thumbnails()

        channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
        channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")

        png_data = unhexlify(
            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
            b"0000001f15c4890000000a49444154789c63000100000500010d"
            b"0a2db40000000049454e44ae426082"
        )

        request1.setResponseCode(200)
        request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
        request1.write(png_data)
        request1.finish()

        self.pump(0.1)

        self.assertEqual(channel1.code, 200, channel1.result["body"])
        self.assertEqual(channel1.result["body"], png_data)

        request2.setResponseCode(200)
        request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
        request2.write(png_data)
        request2.finish()

        self.pump(0.1)

        self.assertEqual(channel2.code, 200, channel2.result["body"])
        self.assertEqual(channel2.result["body"], png_data)

        # We expect only three new thumbnails to have been persisted.
        self.assertEqual(start_count + 3, self._count_remote_thumbnails())

    def _count_remote_media(self) -> int:
        """Count the number of files in our remote media directory.
        """
        path = os.path.join(
            self.hs.get_media_repository().primary_base_path, "remote_content"
        )
        return sum(len(files) for _, _, files in os.walk(path))

    def _count_remote_thumbnails(self) -> int:
        """Count the number of files in our remote thumbnails directory.
        """
        path = os.path.join(
            self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
        )
        return sum(len(files) for _, _, files in os.walk(path))


def get_connection_factory():
    # this needs to happen once, but not until we are ready to run the first test
    global test_server_connection_factory
    if test_server_connection_factory is None:
        test_server_connection_factory = TestServerTLSConnectionFactory(
            sanlist=[b"DNS:example.com"]
        )
    return test_server_connection_factory


def _build_test_server(connection_creator):
    """Construct a test server

    This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol

    Args:
        connection_creator (IOpenSSLServerConnectionCreator): thing to build
            SSL connections
        sanlist (list[bytes]): list of the SAN entries for the cert returned
            by the server

    Returns:
        TLSMemoryBIOProtocol
    """
    server_factory = Factory.forProtocol(HTTPChannel)
    # Request.finish expects the factory to have a 'log' method.
    server_factory.log = _log_request

    server_tls_factory = TLSMemoryBIOFactory(
        connection_creator, isClient=False, wrappedFactory=server_factory
    )

    return server_tls_factory.buildProtocol(None)


def _log_request(request):
    """Implements Factory.log, which is expected by Request.finish"""
    logger.info("Completed request %s", request)
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
        user_dict = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_dict["token_id"]
        token_id = user_dict.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(
@@ -1118,6 +1118,130 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
        self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))


class PushersRestTestCase(unittest.HomeserverTestCase):

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastore()

        self.admin_user = self.register_user("admin", "pass", admin=True)
        self.admin_user_tok = self.login("admin", "pass")

        self.other_user = self.register_user("user", "pass")
        self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
            self.other_user
        )

    def test_no_auth(self):
        """
        Try to list pushers of an user without authentication.
        """
        request, channel = self.make_request("GET", self.url, b"{}")
        self.render(request)

        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

    def test_requester_is_no_admin(self):
        """
        If the user is not a server admin, an error is returned.
        """
        other_user_token = self.login("user", "pass")

        request, channel = self.make_request(
            "GET", self.url, access_token=other_user_token,
        )
        self.render(request)

        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

    def test_user_does_not_exist(self):
        """
        Tests that a lookup for a user that does not exist returns a 404
        """
        url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
        request, channel = self.make_request(
            "GET", url, access_token=self.admin_user_tok,
        )
        self.render(request)

        self.assertEqual(404, channel.code, msg=channel.json_body)
        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

    def test_user_is_not_local(self):
        """
        Tests that a lookup for a user that is not a local returns a 400
        """
        url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"

        request, channel = self.make_request(
            "GET", url, access_token=self.admin_user_tok,
        )
        self.render(request)

        self.assertEqual(400, channel.code, msg=channel.json_body)
        self.assertEqual("Can only lookup local users", channel.json_body["error"])

    def test_get_pushers(self):
        """
        Tests that a normal lookup for pushers is successfully
        """

        # Get pushers
        request, channel = self.make_request(
            "GET", self.url, access_token=self.admin_user_tok,
        )
        self.render(request)

        self.assertEqual(200, channel.code, msg=channel.json_body)
        self.assertEqual(0, channel.json_body["total"])

        # Register the pusher
        other_user_token = self.login("user", "pass")
        user_tuple = self.get_success(
            self.store.get_user_by_access_token(other_user_token)
        )
        token_id = user_tuple.token_id

        self.get_success(
            self.hs.get_pusherpool().add_pusher(
                user_id=self.other_user,
                access_token=token_id,
                kind="http",
                app_id="m.http",
                app_display_name="HTTP Push Notifications",
                device_display_name="pushy push",
                pushkey="a@example.com",
                lang=None,
                data={"url": "example.com"},
            )
        )

        # Get pushers
        request, channel = self.make_request(
            "GET", self.url, access_token=self.admin_user_tok,
        )
        self.render(request)

        self.assertEqual(200, channel.code, msg=channel.json_body)
        self.assertEqual(1, channel.json_body["total"])

        for p in channel.json_body["pushers"]:
            self.assertIn("pushkey", p)
            self.assertIn("kind", p)
            self.assertIn("app_id", p)
            self.assertIn("app_display_name", p)
            self.assertIn("device_display_name", p)
            self.assertIn("profile_tag", p)
            self.assertIn("lang", p)
            self.assertIn("url", p["data"])


class UserMediaRestTestCase(unittest.HomeserverTestCase):

    servlets = [
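Aside: the endpoint the new test case exercises can also be queried directly. A hedged sketch using only the standard library (homeserver URL, admin token and user ID are placeholders; the response keys are the ones asserted on above):

```
import json
import urllib.parse
import urllib.request

# Placeholder values: point these at a real homeserver and admin access token.
BASE_URL = "http://localhost:8008"
ADMIN_TOKEN = "<admin_access_token>"
USER_ID = "@user:example.com"

req = urllib.request.Request(
    "%s/_synapse/admin/v1/users/%s/pushers"
    % (BASE_URL, urllib.parse.quote(USER_ID)),
    headers={"Authorization": "Bearer %s" % ADMIN_TOKEN},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# The test asserts each pusher carries keys such as "pushkey", "kind" and "app_id".
print(body["total"], [p.get("pushkey") for p in body.get("pushers", [])])
```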
@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            self.hs.config.server_name,
            id="1234",
            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
            sender="@as:test",
        )

        self.hs.get_datastore().services_cache.append(appservice)
Some files were not shown because too many files have changed in this diff.