mirror of
https://github.com/element-hq/synapse
synced 2024-10-01 21:32:40 +00:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
1ff3bc332a
108 changed files with 2360 additions and 1540 deletions
|
@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:
|
||||||
|
|
||||||
```
|
```
|
||||||
# Install the dependencies
|
# Install the dependencies
|
||||||
pip install -e ".[lint]"
|
pip install -e ".[lint,mypy]"
|
||||||
|
|
||||||
# Run the linter script
|
# Run the linter script
|
||||||
./scripts-dev/lint.sh
|
./scripts-dev/lint.sh
|
||||||
|
@ -63,7 +63,7 @@ run-time:
|
||||||
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
|
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also provided the `-d` option, which will lint the files that have been
|
You can also provide the `-d` option, which will lint the files that have been
|
||||||
changed since the last git commit. This will often be significantly faster than
|
changed since the last git commit. This will often be significantly faster than
|
||||||
linting the whole codebase.
|
linting the whole codebase.
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ light workloads.
|
||||||
System requirements:
|
System requirements:
|
||||||
|
|
||||||
- POSIX-compliant system (tested on Linux & OS X)
|
- POSIX-compliant system (tested on Linux & OS X)
|
||||||
- Python 3.5.2 or later, up to Python 3.8.
|
- Python 3.5.2 or later, up to Python 3.9.
|
||||||
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
|
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
|
||||||
|
|
||||||
Synapse is written in Python but some of the libraries it uses are written in
|
Synapse is written in Python but some of the libraries it uses are written in
|
||||||
|
|
10
README.rst
10
README.rst
|
@ -256,9 +256,9 @@ directory of your choice::
|
||||||
Synapse has a number of external dependencies, that are easiest
|
Synapse has a number of external dependencies, that are easiest
|
||||||
to install using pip and a virtualenv::
|
to install using pip and a virtualenv::
|
||||||
|
|
||||||
virtualenv -p python3 env
|
python3 -m venv ./env
|
||||||
source env/bin/activate
|
source ./env/bin/activate
|
||||||
python -m pip install --no-use-pep517 -e ".[all]"
|
pip install -e ".[all,test]"
|
||||||
|
|
||||||
This will run a process of downloading and installing all the needed
|
This will run a process of downloading and installing all the needed
|
||||||
dependencies into a virtual env.
|
dependencies into a virtual env.
|
||||||
|
@ -270,9 +270,9 @@ check that everything is installed as it should be::
|
||||||
|
|
||||||
This should end with a 'PASSED' result::
|
This should end with a 'PASSED' result::
|
||||||
|
|
||||||
Ran 143 tests in 0.601s
|
Ran 1266 tests in 643.930s
|
||||||
|
|
||||||
PASSED (successes=143)
|
PASSED (skips=15, successes=1251)
|
||||||
|
|
||||||
Running the Integration Tests
|
Running the Integration Tests
|
||||||
=============================
|
=============================
|
||||||
|
|
16
UPGRADE.rst
16
UPGRADE.rst
|
@ -75,6 +75,22 @@ for example:
|
||||||
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
|
|
||||||
|
Upgrading to v1.23.0
|
||||||
|
====================
|
||||||
|
|
||||||
|
Structured logging configuration breaking changes
|
||||||
|
-------------------------------------------------
|
||||||
|
|
||||||
|
This release deprecates use of the ``structured: true`` logging configuration for
|
||||||
|
structured logging. If your logging configuration contains ``structured: true``
|
||||||
|
then it should be modified based on the `structured logging documentation
|
||||||
|
<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_.
|
||||||
|
|
||||||
|
The ``structured`` and ``drains`` logging options are now deprecated and should
|
||||||
|
be replaced by standard logging configuration of ``handlers`` and ``formatters`.
|
||||||
|
|
||||||
|
A future will release of Synapse will make using ``structured: true`` an error.
|
||||||
|
|
||||||
Upgrading to v1.22.0
|
Upgrading to v1.22.0
|
||||||
====================
|
====================
|
||||||
|
|
||||||
|
|
1
changelog.d/8559.misc
Normal file
1
changelog.d/8559.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Optimise `/createRoom` with multiple invited users.
|
1
changelog.d/8595.misc
Normal file
1
changelog.d/8595.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Implement and use an @lru_cache decorator.
|
1
changelog.d/8607.feature
Normal file
1
changelog.d/8607.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support generating structured logs via the standard logging configuration.
|
1
changelog.d/8610.feature
Normal file
1
changelog.d/8610.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add an admin APIs to allow server admins to list users' pushers. Contributed by @dklimpel.
|
1
changelog.d/8616.misc
Normal file
1
changelog.d/8616.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Change schema to support access tokens belonging to one user but granting access to another.
|
1
changelog.d/8633.misc
Normal file
1
changelog.d/8633.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Run `mypy` as part of the lint.sh script.
|
1
changelog.d/8655.misc
Normal file
1
changelog.d/8655.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add more type hints to the application services code.
|
1
changelog.d/8664.misc
Normal file
1
changelog.d/8664.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Tell Black to format code for Python 3.5.
|
1
changelog.d/8665.doc
Normal file
1
changelog.d/8665.doc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Note support for Python 3.9.
|
1
changelog.d/8666.doc
Normal file
1
changelog.d/8666.doc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Minor updates to docs on running tests.
|
1
changelog.d/8667.doc
Normal file
1
changelog.d/8667.doc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Interlink prometheus/grafana documentation.
|
1
changelog.d/8669.misc
Normal file
1
changelog.d/8669.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Don't pull event from DB when handling replication traffic.
|
1
changelog.d/8671.misc
Normal file
1
changelog.d/8671.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Abstract some invite-related code in preparation for landing knocking.
|
1
changelog.d/8676.bugfix
Normal file
1
changelog.d/8676.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0.
|
1
changelog.d/8678.bugfix
Normal file
1
changelog.d/8678.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.
|
1
changelog.d/8679.misc
Normal file
1
changelog.d/8679.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Clarify representation of events in logfiles.
|
1
changelog.d/8680.misc
Normal file
1
changelog.d/8680.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Don't require `hiredis` package to be installed to run unit tests.
|
1
changelog.d/8682.bugfix
Normal file
1
changelog.d/8682.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories.
|
1
changelog.d/8684.misc
Normal file
1
changelog.d/8684.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix typing info on cache call signature to accept `on_invalidate`.
|
1
changelog.d/8685.feature
Normal file
1
changelog.d/8685.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support generating structured logs via the standard logging configuration.
|
1
changelog.d/8688.misc
Normal file
1
changelog.d/8688.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Abstract some invite-related code in preparation for landing knocking.
|
1
changelog.d/8689.feature
Normal file
1
changelog.d/8689.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add an admin APIs to allow server admins to list users' pushers. Contributed by @dklimpel.
|
1
changelog.d/8690.misc
Normal file
1
changelog.d/8690.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fail tests if they do not await coroutines.
|
|
@ -3,4 +3,4 @@
|
||||||
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
|
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
|
||||||
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
|
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
|
||||||
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
|
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
|
||||||
3. Set up additional recording rules
|
3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus
|
||||||
|
|
|
@ -611,3 +611,82 @@ The following parameters should be set in the URL:
|
||||||
|
|
||||||
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
|
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
|
||||||
- ``device_id`` - The device to delete.
|
- ``device_id`` - The device to delete.
|
||||||
|
|
||||||
|
List all pushers
|
||||||
|
================
|
||||||
|
Gets information about all pushers for a specific ``user_id``.
|
||||||
|
|
||||||
|
The API is::
|
||||||
|
|
||||||
|
GET /_synapse/admin/v1/users/<user_id>/pushers
|
||||||
|
|
||||||
|
To use it, you will need to authenticate by providing an ``access_token`` for a
|
||||||
|
server admin: see `README.rst <README.rst>`_.
|
||||||
|
|
||||||
|
A response body like the following is returned:
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"pushers": [
|
||||||
|
{
|
||||||
|
"app_display_name":"HTTP Push Notifications",
|
||||||
|
"app_id":"m.http",
|
||||||
|
"data": {
|
||||||
|
"url":"example.com"
|
||||||
|
},
|
||||||
|
"device_display_name":"pushy push",
|
||||||
|
"kind":"http",
|
||||||
|
"lang":"None",
|
||||||
|
"profile_tag":"",
|
||||||
|
"pushkey":"a@example.com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
**Parameters**
|
||||||
|
|
||||||
|
The following parameters should be set in the URL:
|
||||||
|
|
||||||
|
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
|
||||||
|
|
||||||
|
**Response**
|
||||||
|
|
||||||
|
The following fields are returned in the JSON response body:
|
||||||
|
|
||||||
|
- ``pushers`` - An array containing the current pushers for the user
|
||||||
|
|
||||||
|
- ``app_display_name`` - string - A string that will allow the user to identify
|
||||||
|
what application owns this pusher.
|
||||||
|
|
||||||
|
- ``app_id`` - string - This is a reverse-DNS style identifier for the application.
|
||||||
|
Max length, 64 chars.
|
||||||
|
|
||||||
|
- ``data`` - A dictionary of information for the pusher implementation itself.
|
||||||
|
|
||||||
|
- ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send
|
||||||
|
notifications to.
|
||||||
|
|
||||||
|
- ``format`` - string - The format to use when sending notifications to the
|
||||||
|
Push Gateway.
|
||||||
|
|
||||||
|
- ``device_display_name`` - string - A string that will allow the user to identify
|
||||||
|
what device owns this pusher.
|
||||||
|
|
||||||
|
- ``profile_tag`` - string - This string determines which set of device specific rules
|
||||||
|
this pusher executes.
|
||||||
|
|
||||||
|
- ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes.
|
||||||
|
- ``lang`` - string - The preferred language for receiving notifications
|
||||||
|
(e.g. 'en' or 'en-US')
|
||||||
|
|
||||||
|
- ``profile_tag`` - string - This string determines which set of device specific rules
|
||||||
|
this pusher executes.
|
||||||
|
|
||||||
|
- ``pushkey`` - string - This is a unique identifier for this pusher.
|
||||||
|
Max length, 512 bytes.
|
||||||
|
|
||||||
|
- ``total`` - integer - Number of pushers.
|
||||||
|
|
||||||
|
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
|
||||||
|
|
|
@ -60,6 +60,8 @@
|
||||||
|
|
||||||
1. Restart Prometheus.
|
1. Restart Prometheus.
|
||||||
|
|
||||||
|
1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
|
||||||
|
|
||||||
## Monitoring workers
|
## Monitoring workers
|
||||||
|
|
||||||
To monitor a Synapse installation using
|
To monitor a Synapse installation using
|
||||||
|
|
|
@ -3,7 +3,11 @@
|
||||||
# This is a YAML file containing a standard Python logging configuration
|
# This is a YAML file containing a standard Python logging configuration
|
||||||
# dictionary. See [1] for details on the valid settings.
|
# dictionary. See [1] for details on the valid settings.
|
||||||
#
|
#
|
||||||
|
# Synapse also supports structured logging for machine readable logs which can
|
||||||
|
# be ingested by ELK stacks. See [2] for details.
|
||||||
|
#
|
||||||
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
|
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
|
||||||
|
# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
|
||||||
|
|
||||||
version: 1
|
version: 1
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,116 @@
|
||||||
# Structured Logging
|
# Structured Logging
|
||||||
|
|
||||||
A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".
|
A structured logging system can be useful when your logs are destined for a
|
||||||
|
machine to parse and process. By maintaining its machine-readable characteristics,
|
||||||
|
it enables more efficient searching and aggregations when consumed by software
|
||||||
|
such as the "ELK stack".
|
||||||
|
|
||||||
Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).
|
Synapse's structured logging system is configured via the file that Synapse's
|
||||||
|
`log_config` config option points to. The file should include a formatter which
|
||||||
|
uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a
|
||||||
|
handler which uses the above formatter.
|
||||||
|
|
||||||
|
There is also a `synapse.logging.JsonFormatter` option which does not include
|
||||||
|
a timestamp in the resulting JSON. This is useful if the log ingester adds its
|
||||||
|
own timestamp.
|
||||||
|
|
||||||
A structured logging configuration looks similar to the following:
|
A structured logging configuration looks similar to the following:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: 1
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
structured:
|
||||||
|
class: synapse.logging.TerseJsonFormatter
|
||||||
|
|
||||||
|
handlers:
|
||||||
|
file:
|
||||||
|
class: logging.handlers.TimedRotatingFileHandler
|
||||||
|
formatter: structured
|
||||||
|
filename: /path/to/my/logs/homeserver.log
|
||||||
|
when: midnight
|
||||||
|
backupCount: 3 # Does not include the current log file.
|
||||||
|
encoding: utf8
|
||||||
|
|
||||||
|
loggers:
|
||||||
|
synapse:
|
||||||
|
level: INFO
|
||||||
|
handlers: [remote]
|
||||||
|
synapse.storage.SQL:
|
||||||
|
level: WARNING
|
||||||
|
```
|
||||||
|
|
||||||
|
The above logging config will set Synapse as 'INFO' logging level by default,
|
||||||
|
with the SQL layer at 'WARNING', and will log to a file, stored as JSON.
|
||||||
|
|
||||||
|
It is also possible to figure Synapse to log to a remote endpoint by using the
|
||||||
|
`synapse.logging.RemoteHandler` class included with Synapse. It takes the
|
||||||
|
following arguments:
|
||||||
|
|
||||||
|
- `host`: Hostname or IP address of the log aggregator.
|
||||||
|
- `port`: Numerical port to contact on the host.
|
||||||
|
- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow.
|
||||||
|
|
||||||
|
A remote structured logging configuration looks similar to the following:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: 1
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
structured:
|
||||||
|
class: synapse.logging.TerseJsonFormatter
|
||||||
|
|
||||||
|
handlers:
|
||||||
|
remote:
|
||||||
|
class: synapse.logging.RemoteHandler
|
||||||
|
formatter: structured
|
||||||
|
host: 10.1.2.3
|
||||||
|
port: 9999
|
||||||
|
|
||||||
|
loggers:
|
||||||
|
synapse:
|
||||||
|
level: INFO
|
||||||
|
handlers: [remote]
|
||||||
|
synapse.storage.SQL:
|
||||||
|
level: WARNING
|
||||||
|
```
|
||||||
|
|
||||||
|
The above logging config will set Synapse as 'INFO' logging level by default,
|
||||||
|
with the SQL layer at 'WARNING', and will log JSON formatted messages to a
|
||||||
|
remote endpoint at 10.1.2.3:9999.
|
||||||
|
|
||||||
|
## Upgrading from legacy structured logging configuration
|
||||||
|
|
||||||
|
Versions of Synapse prior to v1.23.0 included a custom structured logging
|
||||||
|
configuration which is deprecated. It used a `structured: true` flag and
|
||||||
|
configured `drains` instead of ``handlers`` and `formatters`.
|
||||||
|
|
||||||
|
Synapse currently automatically converts the old configuration to the new
|
||||||
|
configuration, but this will be removed in a future version of Synapse. The
|
||||||
|
following reference can be used to update your configuration. Based on the drain
|
||||||
|
`type`, we can pick a new handler:
|
||||||
|
|
||||||
|
1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
|
||||||
|
with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
|
||||||
|
or `ext://sys.stderr` should be used.
|
||||||
|
2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with
|
||||||
|
a location of the file path should be used.
|
||||||
|
3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler`
|
||||||
|
with the host and port should be used.
|
||||||
|
|
||||||
|
Then based on the drain `type` we can pick a new formatter:
|
||||||
|
|
||||||
|
1. For a type of `console` or `file` no formatter is necessary.
|
||||||
|
2. For a type of `console_json` or `file_json`: a formatter of
|
||||||
|
`synapse.logging.JsonFormatter` should be used.
|
||||||
|
3. For a type of `console_json_terse` or `network_json_terse`: a formatter of
|
||||||
|
`synapse.logging.TerseJsonFormatter` should be used.
|
||||||
|
|
||||||
|
For each new handler and formatter they should be added to the logging configuration
|
||||||
|
and then assigned to either a logger or the root logger.
|
||||||
|
|
||||||
|
An example legacy configuration:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
structured: true
|
structured: true
|
||||||
|
|
||||||
|
@ -24,60 +129,33 @@ drains:
|
||||||
location: homeserver.log
|
location: homeserver.log
|
||||||
```
|
```
|
||||||
|
|
||||||
The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
|
Would be converted into a new configuration:
|
||||||
|
|
||||||
## Drain Types
|
```yaml
|
||||||
|
version: 1
|
||||||
|
|
||||||
Drain types can be specified by the `type` key.
|
formatters:
|
||||||
|
json:
|
||||||
|
class: synapse.logging.JsonFormatter
|
||||||
|
|
||||||
### `console`
|
handlers:
|
||||||
|
console:
|
||||||
|
class: logging.StreamHandler
|
||||||
|
location: ext://sys.stdout
|
||||||
|
file:
|
||||||
|
class: logging.FileHandler
|
||||||
|
formatter: json
|
||||||
|
filename: homeserver.log
|
||||||
|
|
||||||
Outputs human-readable logs to the console.
|
loggers:
|
||||||
|
synapse:
|
||||||
|
level: INFO
|
||||||
|
handlers: [console, file]
|
||||||
|
synapse.storage.SQL:
|
||||||
|
level: WARNING
|
||||||
|
```
|
||||||
|
|
||||||
Arguments:
|
The new logging configuration is a bit more verbose, but significantly more
|
||||||
|
flexible. It allows for configuration that were not previously possible, such as
|
||||||
- `location`: Either `stdout` or `stderr`.
|
sending plain logs over the network, or using different handlers for different
|
||||||
|
modules.
|
||||||
### `console_json`
|
|
||||||
|
|
||||||
Outputs machine-readable JSON logs to the console.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
- `location`: Either `stdout` or `stderr`.
|
|
||||||
|
|
||||||
### `console_json_terse`
|
|
||||||
|
|
||||||
Outputs machine-readable JSON logs to the console, separated by newlines. This
|
|
||||||
format is not designed to be read and re-formatted into human-readable text, but
|
|
||||||
is optimal for a logging aggregation system.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
- `location`: Either `stdout` or `stderr`.
|
|
||||||
|
|
||||||
### `file`
|
|
||||||
|
|
||||||
Outputs human-readable logs to a file.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
- `location`: An absolute path to the file to log to.
|
|
||||||
|
|
||||||
### `file_json`
|
|
||||||
|
|
||||||
Outputs machine-readable logs to a file.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
- `location`: An absolute path to the file to log to.
|
|
||||||
|
|
||||||
### `network_json_terse`
|
|
||||||
|
|
||||||
Delivers machine-readable JSON logs to a log aggregator over TCP. This is
|
|
||||||
compatible with LogStash's TCP input with the codec set to `json_lines`.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
- `host`: Hostname or IP address of the log aggregator.
|
|
||||||
- `port`: Numerical port to contact on the host.
|
|
||||||
|
|
4
mypy.ini
4
mypy.ini
|
@ -57,6 +57,7 @@ files =
|
||||||
synapse/server_notices,
|
synapse/server_notices,
|
||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
|
synapse/storage/databases/main/appservice.py,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
synapse/storage/databases/main/registration.py,
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
|
@ -82,6 +83,9 @@ ignore_missing_imports = True
|
||||||
[mypy-zope]
|
[mypy-zope]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-bcrypt]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
[mypy-constantly]
|
[mypy-constantly]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@
|
||||||
showcontent = true
|
showcontent = true
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
target-version = ['py34']
|
target-version = ['py35']
|
||||||
exclude = '''
|
exclude = '''
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|
|
@ -80,7 +80,7 @@ else
|
||||||
# then lint everything!
|
# then lint everything!
|
||||||
if [[ -z ${files+x} ]]; then
|
if [[ -z ${files+x} ]]; then
|
||||||
# Lint all source code files and directories
|
# Lint all source code files and directories
|
||||||
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
|
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -94,3 +94,4 @@ isort "${files[@]}"
|
||||||
python3 -m black "${files[@]}"
|
python3 -m black "${files[@]}"
|
||||||
./scripts-dev/config-lint.sh
|
./scripts-dev/config-lint.sh
|
||||||
flake8 "${files[@]}"
|
flake8 "${files[@]}"
|
||||||
|
mypy
|
||||||
|
|
|
@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors.
|
||||||
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from mypy.nodes import ARG_NAMED_OPT
|
||||||
from mypy.plugin import MethodSigContext, Plugin
|
from mypy.plugin import MethodSigContext, Plugin
|
||||||
from mypy.typeops import bind_self
|
from mypy.typeops import bind_self
|
||||||
from mypy.types import CallableType
|
from mypy.types import CallableType, NoneType
|
||||||
|
|
||||||
|
|
||||||
class SynapsePlugin(Plugin):
|
class SynapsePlugin(Plugin):
|
||||||
|
@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||||
|
|
||||||
It already has *almost* the correct signature, except:
|
It already has *almost* the correct signature, except:
|
||||||
|
|
||||||
1. the `self` argument needs to be marked as "bound"; and
|
1. the `self` argument needs to be marked as "bound";
|
||||||
2. any `cache_context` argument should be removed.
|
2. any `cache_context` argument should be removed;
|
||||||
|
3. an optional keyword argument `on_invalidated` should be added.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First we mark this as a bound function signature.
|
# First we mark this as a bound function signature.
|
||||||
|
@ -58,19 +60,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||||
context_arg_index = idx
|
context_arg_index = idx
|
||||||
break
|
break
|
||||||
|
|
||||||
|
arg_types = list(signature.arg_types)
|
||||||
|
arg_names = list(signature.arg_names)
|
||||||
|
arg_kinds = list(signature.arg_kinds)
|
||||||
|
|
||||||
if context_arg_index:
|
if context_arg_index:
|
||||||
arg_types = list(signature.arg_types)
|
|
||||||
arg_types.pop(context_arg_index)
|
arg_types.pop(context_arg_index)
|
||||||
|
|
||||||
arg_names = list(signature.arg_names)
|
|
||||||
arg_names.pop(context_arg_index)
|
arg_names.pop(context_arg_index)
|
||||||
|
|
||||||
arg_kinds = list(signature.arg_kinds)
|
|
||||||
arg_kinds.pop(context_arg_index)
|
arg_kinds.pop(context_arg_index)
|
||||||
|
|
||||||
signature = signature.copy_modified(
|
# Third, we add an optional "on_invalidate" argument.
|
||||||
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
|
#
|
||||||
)
|
# This is a callable which accepts no input and returns nothing.
|
||||||
|
calltyp = CallableType(
|
||||||
|
arg_types=[],
|
||||||
|
arg_kinds=[],
|
||||||
|
arg_names=[],
|
||||||
|
ret_type=NoneType(),
|
||||||
|
fallback=ctx.api.named_generic_type("builtins.function", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
arg_types.append(calltyp)
|
||||||
|
arg_names.append("on_invalidate")
|
||||||
|
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
|
||||||
|
|
||||||
|
signature = signature.copy_modified(
|
||||||
|
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
|
||||||
|
)
|
||||||
|
|
||||||
return signature
|
return signature
|
||||||
|
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -131,6 +131,7 @@ setup(
|
||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
],
|
],
|
||||||
scripts=["synctl"] + glob.glob("scripts/*"),
|
scripts=["synctl"] + glob.glob("scripts/*"),
|
||||||
cmdclass={"test": TestCommand},
|
cmdclass={"test": TestCommand},
|
||||||
|
|
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging import opentracing as opentracing
|
from synapse.logging import opentracing as opentracing
|
||||||
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import StateMap, UserID
|
from synapse.types import StateMap, UserID
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -190,10 +191,6 @@ class Auth:
|
||||||
|
|
||||||
user_id, app_service = await self._get_appservice_user_id(request)
|
user_id, app_service = await self._get_appservice_user_id(request)
|
||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
|
||||||
opentracing.set_tag("authenticated_entity", user_id)
|
|
||||||
opentracing.set_tag("appservice_id", app_service.id)
|
|
||||||
|
|
||||||
if ip_addr and self._track_appservice_user_ips:
|
if ip_addr and self._track_appservice_user_ips:
|
||||||
await self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -203,31 +200,38 @@ class Auth:
|
||||||
device_id="dummy-device", # stubbed
|
device_id="dummy-device", # stubbed
|
||||||
)
|
)
|
||||||
|
|
||||||
return synapse.types.create_requester(user_id, app_service=app_service)
|
requester = synapse.types.create_requester(
|
||||||
|
user_id, app_service=app_service
|
||||||
|
)
|
||||||
|
|
||||||
|
request.requester = user_id
|
||||||
|
opentracing.set_tag("authenticated_entity", user_id)
|
||||||
|
opentracing.set_tag("user_id", user_id)
|
||||||
|
opentracing.set_tag("appservice_id", app_service.id)
|
||||||
|
|
||||||
|
return requester
|
||||||
|
|
||||||
user_info = await self.get_user_by_access_token(
|
user_info = await self.get_user_by_access_token(
|
||||||
access_token, rights, allow_expired=allow_expired
|
access_token, rights, allow_expired=allow_expired
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
token_id = user_info.token_id
|
||||||
token_id = user_info["token_id"]
|
is_guest = user_info.is_guest
|
||||||
is_guest = user_info["is_guest"]
|
shadow_banned = user_info.shadow_banned
|
||||||
shadow_banned = user_info["shadow_banned"]
|
|
||||||
|
|
||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
if self._account_validity.enabled and not allow_expired:
|
if self._account_validity.enabled and not allow_expired:
|
||||||
user_id = user.to_string()
|
if await self.store.is_account_expired(
|
||||||
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
user_info.user_id, self.clock.time_msec()
|
||||||
|
):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||||
)
|
)
|
||||||
|
|
||||||
# device_id may not be present if get_user_by_access_token has been
|
device_id = user_info.device_id
|
||||||
# stubbed out.
|
|
||||||
device_id = user_info.get("device_id")
|
|
||||||
|
|
||||||
if user and access_token and ip_addr:
|
if access_token and ip_addr:
|
||||||
await self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id=user.to_string(),
|
user_id=user_info.token_owner,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
@ -241,19 +245,23 @@ class Auth:
|
||||||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
requester = synapse.types.create_requester(
|
||||||
opentracing.set_tag("authenticated_entity", user.to_string())
|
user_info.user_id,
|
||||||
if device_id:
|
|
||||||
opentracing.set_tag("device_id", device_id)
|
|
||||||
|
|
||||||
return synapse.types.create_requester(
|
|
||||||
user,
|
|
||||||
token_id,
|
token_id,
|
||||||
is_guest,
|
is_guest,
|
||||||
shadow_banned,
|
shadow_banned,
|
||||||
device_id,
|
device_id,
|
||||||
app_service=app_service,
|
app_service=app_service,
|
||||||
|
authenticated_entity=user_info.token_owner,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
request.requester = requester
|
||||||
|
opentracing.set_tag("authenticated_entity", user_info.token_owner)
|
||||||
|
opentracing.set_tag("user_id", user_info.user_id)
|
||||||
|
if device_id:
|
||||||
|
opentracing.set_tag("device_id", device_id)
|
||||||
|
|
||||||
|
return requester
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise MissingClientTokenError()
|
raise MissingClientTokenError()
|
||||||
|
|
||||||
|
@ -284,7 +292,7 @@ class Auth:
|
||||||
|
|
||||||
async def get_user_by_access_token(
|
async def get_user_by_access_token(
|
||||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
self, token: str, rights: str = "access", allow_expired: bool = False,
|
||||||
) -> dict:
|
) -> TokenLookupResult:
|
||||||
""" Validate access token and get user_id from it
|
""" Validate access token and get user_id from it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -293,13 +301,7 @@ class Auth:
|
||||||
allow this
|
allow this
|
||||||
allow_expired: If False, raises an InvalidClientTokenError
|
allow_expired: If False, raises an InvalidClientTokenError
|
||||||
if the token is expired
|
if the token is expired
|
||||||
Returns:
|
|
||||||
dict that includes:
|
|
||||||
`user` (UserID)
|
|
||||||
`is_guest` (bool)
|
|
||||||
`shadow_banned` (bool)
|
|
||||||
`token_id` (int|None): access token id. May be None if guest
|
|
||||||
`device_id` (str|None): device corresponding to access token
|
|
||||||
Raises:
|
Raises:
|
||||||
InvalidClientTokenError if a user by that token exists, but the token is
|
InvalidClientTokenError if a user by that token exists, but the token is
|
||||||
expired
|
expired
|
||||||
|
@ -309,9 +311,9 @@ class Auth:
|
||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
# first look in the database
|
# first look in the database
|
||||||
r = await self._look_up_user_by_access_token(token)
|
r = await self.store.get_user_by_access_token(token)
|
||||||
if r:
|
if r:
|
||||||
valid_until_ms = r["valid_until_ms"]
|
valid_until_ms = r.valid_until_ms
|
||||||
if (
|
if (
|
||||||
not allow_expired
|
not allow_expired
|
||||||
and valid_until_ms is not None
|
and valid_until_ms is not None
|
||||||
|
@ -328,7 +330,6 @@ class Auth:
|
||||||
# otherwise it needs to be a valid macaroon
|
# otherwise it needs to be a valid macaroon
|
||||||
try:
|
try:
|
||||||
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
if not guest:
|
if not guest:
|
||||||
|
@ -354,23 +355,17 @@ class Auth:
|
||||||
raise InvalidClientTokenError(
|
raise InvalidClientTokenError(
|
||||||
"Guest access token used for regular user"
|
"Guest access token used for regular user"
|
||||||
)
|
)
|
||||||
ret = {
|
|
||||||
"user": user,
|
ret = TokenLookupResult(
|
||||||
"is_guest": True,
|
user_id=user_id,
|
||||||
"shadow_banned": False,
|
is_guest=True,
|
||||||
"token_id": None,
|
|
||||||
# all guests get the same device id
|
# all guests get the same device id
|
||||||
"device_id": GUEST_DEVICE_ID,
|
device_id=GUEST_DEVICE_ID,
|
||||||
}
|
)
|
||||||
elif rights == "delete_pusher":
|
elif rights == "delete_pusher":
|
||||||
# We don't store these tokens in the database
|
# We don't store these tokens in the database
|
||||||
ret = {
|
|
||||||
"user": user,
|
ret = TokenLookupResult(user_id=user_id, is_guest=False)
|
||||||
"is_guest": False,
|
|
||||||
"shadow_banned": False,
|
|
||||||
"token_id": None,
|
|
||||||
"device_id": None,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unknown rights setting %s", rights)
|
raise RuntimeError("Unknown rights setting %s", rights)
|
||||||
return ret
|
return ret
|
||||||
|
@ -479,31 +474,15 @@ class Auth:
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _look_up_user_by_access_token(self, token):
|
|
||||||
ret = await self.store.get_user_by_access_token(token)
|
|
||||||
if not ret:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# we use ret.get() below because *lots* of unit tests stub out
|
|
||||||
# get_user_by_access_token in a way where it only returns a couple of
|
|
||||||
# the fields.
|
|
||||||
user_info = {
|
|
||||||
"user": UserID.from_string(ret.get("name")),
|
|
||||||
"token_id": ret.get("token_id", None),
|
|
||||||
"is_guest": False,
|
|
||||||
"shadow_banned": ret.get("shadow_banned"),
|
|
||||||
"device_id": ret.get("device_id"),
|
|
||||||
"valid_until_ms": ret.get("valid_until_ms"),
|
|
||||||
}
|
|
||||||
return user_info
|
|
||||||
|
|
||||||
def get_appservice_by_req(self, request):
|
def get_appservice_by_req(self, request):
|
||||||
token = self.get_access_token_from_request(request)
|
token = self.get_access_token_from_request(request)
|
||||||
service = self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
if not service:
|
if not service:
|
||||||
logger.warning("Unrecognised appservice access token.")
|
logger.warning("Unrecognised appservice access token.")
|
||||||
raise InvalidClientTokenError()
|
raise InvalidClientTokenError()
|
||||||
request.authenticated_entity = service.sender
|
request.requester = synapse.types.create_requester(
|
||||||
|
service.sender, app_service=service
|
||||||
|
)
|
||||||
return service
|
return service
|
||||||
|
|
||||||
async def is_server_admin(self, user: UserID) -> bool:
|
async def is_server_admin(self, user: UserID) -> bool:
|
||||||
|
|
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
|
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.appservice.api import ApplicationServiceApi
|
from synapse.appservice.api import ApplicationServiceApi
|
||||||
|
@ -52,11 +52,11 @@ class ApplicationService:
|
||||||
self,
|
self,
|
||||||
token,
|
token,
|
||||||
hostname,
|
hostname,
|
||||||
|
id,
|
||||||
|
sender,
|
||||||
url=None,
|
url=None,
|
||||||
namespaces=None,
|
namespaces=None,
|
||||||
hs_token=None,
|
hs_token=None,
|
||||||
sender=None,
|
|
||||||
id=None,
|
|
||||||
protocols=None,
|
protocols=None,
|
||||||
rate_limited=True,
|
rate_limited=True,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
|
@ -164,9 +164,9 @@ class ApplicationService:
|
||||||
does_match = await self.matches_user_in_member_list(event.room_id, store)
|
does_match = await self.matches_user_in_member_list(event.room_id, store)
|
||||||
return does_match
|
return does_match
|
||||||
|
|
||||||
@cached(num_args=1)
|
@cached(num_args=1, cache_context=True)
|
||||||
async def matches_user_in_member_list(
|
async def matches_user_in_member_list(
|
||||||
self, room_id: str, store: "DataStore"
|
self, room_id: str, store: "DataStore", cache_context: _CacheContext,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if this service is interested a room based upon it's membership
|
"""Check if this service is interested a room based upon it's membership
|
||||||
|
|
||||||
|
@ -177,7 +177,9 @@ class ApplicationService:
|
||||||
Returns:
|
Returns:
|
||||||
True if this service would like to know about this room.
|
True if this service would like to know about this room.
|
||||||
"""
|
"""
|
||||||
member_list = await store.get_users_in_room(room_id)
|
member_list = await store.get_users_in_room(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
|
)
|
||||||
|
|
||||||
# check joined member events
|
# check joined member events
|
||||||
for user_id in member_list:
|
for user_id in member_list:
|
||||||
|
|
|
@ -23,7 +23,6 @@ from string import Template
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from twisted.logger import (
|
from twisted.logger import (
|
||||||
ILogObserver,
|
|
||||||
LogBeginner,
|
LogBeginner,
|
||||||
STDLibLogObserver,
|
STDLibLogObserver,
|
||||||
eventAsText,
|
eventAsText,
|
||||||
|
@ -32,11 +31,9 @@ from twisted.logger import (
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
from synapse.app import _base as appbase
|
from synapse.app import _base as appbase
|
||||||
from synapse.logging._structured import (
|
from synapse.logging._structured import setup_structured_logging
|
||||||
reload_structured_logging,
|
|
||||||
setup_structured_logging,
|
|
||||||
)
|
|
||||||
from synapse.logging.context import LoggingContextFilter
|
from synapse.logging.context import LoggingContextFilter
|
||||||
|
from synapse.logging.filter import MetadataFilter
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
|
||||||
# This is a YAML file containing a standard Python logging configuration
|
# This is a YAML file containing a standard Python logging configuration
|
||||||
# dictionary. See [1] for details on the valid settings.
|
# dictionary. See [1] for details on the valid settings.
|
||||||
#
|
#
|
||||||
|
# Synapse also supports structured logging for machine readable logs which can
|
||||||
|
# be ingested by ELK stacks. See [2] for details.
|
||||||
|
#
|
||||||
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
|
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
|
||||||
|
# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
|
||||||
|
|
||||||
version: 1
|
version: 1
|
||||||
|
|
||||||
|
@ -176,11 +177,11 @@ class LoggingConfig(Config):
|
||||||
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
|
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
|
||||||
|
|
||||||
|
|
||||||
def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
|
||||||
"""
|
"""
|
||||||
Set up Python stdlib logging.
|
Set up Python standard library logging.
|
||||||
"""
|
"""
|
||||||
if log_config is None:
|
if log_config_path is None:
|
||||||
log_format = (
|
log_format = (
|
||||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||||
" - %(message)s"
|
" - %(message)s"
|
||||||
|
@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
else:
|
else:
|
||||||
logging.config.dictConfig(log_config)
|
# Load the logging configuration.
|
||||||
|
_load_logging_config(log_config_path)
|
||||||
|
|
||||||
# We add a log record factory that runs all messages through the
|
# We add a log record factory that runs all messages through the
|
||||||
# LoggingContextFilter so that we get the context *at the time we log*
|
# LoggingContextFilter so that we get the context *at the time we log*
|
||||||
|
@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
||||||
# filter options, but care must when using e.g. MemoryHandler to buffer
|
# filter options, but care must when using e.g. MemoryHandler to buffer
|
||||||
# writes.
|
# writes.
|
||||||
|
|
||||||
log_filter = LoggingContextFilter(request="")
|
log_context_filter = LoggingContextFilter(request="")
|
||||||
|
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
|
||||||
old_factory = logging.getLogRecordFactory()
|
old_factory = logging.getLogRecordFactory()
|
||||||
|
|
||||||
def factory(*args, **kwargs):
|
def factory(*args, **kwargs):
|
||||||
record = old_factory(*args, **kwargs)
|
record = old_factory(*args, **kwargs)
|
||||||
log_filter.filter(record)
|
log_context_filter.filter(record)
|
||||||
|
log_metadata_filter.filter(record)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
logging.setLogRecordFactory(factory)
|
logging.setLogRecordFactory(factory)
|
||||||
|
@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
|
||||||
if not config.no_redirect_stdio:
|
if not config.no_redirect_stdio:
|
||||||
print("Redirected stdout/stderr to logs")
|
print("Redirected stdout/stderr to logs")
|
||||||
|
|
||||||
return observer
|
|
||||||
|
|
||||||
|
def _load_logging_config(log_config_path: str) -> None:
|
||||||
def _reload_stdlib_logging(*args, log_config=None):
|
"""
|
||||||
logger = logging.getLogger("")
|
Configure logging from a log config path.
|
||||||
|
"""
|
||||||
|
with open(log_config_path, "rb") as f:
|
||||||
|
log_config = yaml.safe_load(f.read())
|
||||||
|
|
||||||
if not log_config:
|
if not log_config:
|
||||||
logger.warning("Reloaded a blank config?")
|
logging.warning("Loaded a blank logging config?")
|
||||||
|
|
||||||
|
# If the old structured logging configuration is being used, convert it to
|
||||||
|
# the new style configuration.
|
||||||
|
if "structured" in log_config and log_config.get("structured"):
|
||||||
|
log_config = setup_structured_logging(log_config)
|
||||||
|
|
||||||
logging.config.dictConfig(log_config)
|
logging.config.dictConfig(log_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_logging_config(log_config_path):
|
||||||
|
"""
|
||||||
|
Reload the log configuration from the file and apply it.
|
||||||
|
"""
|
||||||
|
# If no log config path was given, it cannot be reloaded.
|
||||||
|
if log_config_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
_load_logging_config(log_config_path)
|
||||||
|
logging.info("Reloaded log config from %s due to SIGHUP", log_config_path)
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(
|
def setup_logging(
|
||||||
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
|
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
|
||||||
) -> ILogObserver:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Set up the logging subsystem.
|
Set up the logging subsystem.
|
||||||
|
|
||||||
|
@ -282,41 +305,18 @@ def setup_logging(
|
||||||
|
|
||||||
logBeginner: The Twisted logBeginner to use.
|
logBeginner: The Twisted logBeginner to use.
|
||||||
|
|
||||||
Returns:
|
|
||||||
The "root" Twisted Logger observer, suitable for sending logs to from a
|
|
||||||
Logger instance.
|
|
||||||
"""
|
"""
|
||||||
log_config = config.worker_log_config if use_worker_options else config.log_config
|
log_config_path = (
|
||||||
|
config.worker_log_config if use_worker_options else config.log_config
|
||||||
|
)
|
||||||
|
|
||||||
def read_config(*args, callback=None):
|
# Perform one-time logging configuration.
|
||||||
if log_config is None:
|
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
|
||||||
return None
|
# Add a SIGHUP handler to reload the logging configuration, if one is available.
|
||||||
|
appbase.register_sighup(_reload_logging_config, log_config_path)
|
||||||
|
|
||||||
with open(log_config, "rb") as f:
|
# Log immediately so we can grep backwards.
|
||||||
log_config_body = yaml.safe_load(f.read())
|
|
||||||
|
|
||||||
if callback:
|
|
||||||
callback(log_config=log_config_body)
|
|
||||||
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
|
|
||||||
|
|
||||||
return log_config_body
|
|
||||||
|
|
||||||
log_config_body = read_config()
|
|
||||||
|
|
||||||
if log_config_body and log_config_body.get("structured") is True:
|
|
||||||
logger = setup_structured_logging(
|
|
||||||
hs, config, log_config_body, logBeginner=logBeginner
|
|
||||||
)
|
|
||||||
appbase.register_sighup(read_config, callback=reload_structured_logging)
|
|
||||||
else:
|
|
||||||
logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
|
|
||||||
appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
|
|
||||||
|
|
||||||
# make sure that the first thing we log is a thing we can grep backwards
|
|
||||||
# for
|
|
||||||
logging.warning("***** STARTING SERVER *****")
|
logging.warning("***** STARTING SERVER *****")
|
||||||
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
||||||
logging.info("Server hostname: %s", config.server_name)
|
logging.info("Server hostname: %s", config.server_name)
|
||||||
logging.info("Instance name: %s", hs.get_instance_name())
|
logging.info("Instance name: %s", hs.get_instance_name())
|
||||||
|
|
||||||
return logger
|
|
||||||
|
|
|
@ -368,7 +368,7 @@ class FrozenEvent(EventBase):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
|
return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
|
||||||
self.get("event_id", None),
|
self.get("event_id", None),
|
||||||
self.get("type", None),
|
self.get("type", None),
|
||||||
self.get("state_key", None),
|
self.get("state_key", None),
|
||||||
|
@ -451,7 +451,7 @@ class FrozenEventV2(EventBase):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<%s event_id='%s', type='%s', state_key='%s'>" % (
|
return "<%s event_id=%r, type=%r, state_key=%r>" % (
|
||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
self.event_id,
|
self.event_id,
|
||||||
self.get("type", None),
|
self.get("type", None),
|
||||||
|
|
|
@ -154,7 +154,7 @@ class Authenticator:
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Request from %s", origin)
|
logger.debug("Request from %s", origin)
|
||||||
request.authenticated_entity = origin
|
request.requester = origin
|
||||||
|
|
||||||
# If we get a valid signed request from the other side, its probably
|
# If we get a valid signed request from the other side, its probably
|
||||||
# alive
|
# alive
|
||||||
|
|
|
@ -12,9 +12,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
|
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||||
|
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServicesHandler:
|
class ApplicationServicesHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.appservice_api = hs.get_application_service_api()
|
self.appservice_api = hs.get_application_service_api()
|
||||||
|
@ -247,7 +250,9 @@ class ApplicationServicesHandler:
|
||||||
service, "presence", new_token
|
service, "presence", new_token
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_typing(self, service: ApplicationService, new_token: int):
|
async def _handle_typing(
|
||||||
|
self, service: ApplicationService, new_token: int
|
||||||
|
) -> List[JsonDict]:
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources["typing"]
|
||||||
# Get the typing events from just before current
|
# Get the typing events from just before current
|
||||||
typing, _ = await typing_source.get_new_events_as(
|
typing, _ = await typing_source.get_new_events_as(
|
||||||
|
@ -259,7 +264,7 @@ class ApplicationServicesHandler:
|
||||||
)
|
)
|
||||||
return typing
|
return typing
|
||||||
|
|
||||||
async def _handle_receipts(self, service: ApplicationService):
|
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
service, "read_receipt"
|
service, "read_receipt"
|
||||||
)
|
)
|
||||||
|
@ -271,7 +276,7 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
async def _handle_presence(
|
async def _handle_presence(
|
||||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||||
):
|
) -> List[JsonDict]:
|
||||||
events = [] # type: List[JsonDict]
|
events = [] # type: List[JsonDict]
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources["presence"]
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
|
@ -301,11 +306,11 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
async def query_user_exists(self, user_id):
|
async def query_user_exists(self, user_id: str) -> bool:
|
||||||
"""Check if any application service knows this user_id exists.
|
"""Check if any application service knows this user_id exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to query if they exist on any AS.
|
user_id: The user to query if they exist on any AS.
|
||||||
Returns:
|
Returns:
|
||||||
True if this user exists on at least one application service.
|
True if this user exists on at least one application service.
|
||||||
"""
|
"""
|
||||||
|
@ -316,11 +321,13 @@ class ApplicationServicesHandler:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def query_room_alias_exists(self, room_alias):
|
async def query_room_alias_exists(
|
||||||
|
self, room_alias: RoomAlias
|
||||||
|
) -> Optional[RoomAliasMapping]:
|
||||||
"""Check if an application service knows this room alias exists.
|
"""Check if an application service knows this room alias exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_alias(RoomAlias): The room alias to query.
|
room_alias: The room alias to query.
|
||||||
Returns:
|
Returns:
|
||||||
namedtuple: with keys "room_id" and "servers" or None if no
|
namedtuple: with keys "room_id" and "servers" or None if no
|
||||||
association can be found.
|
association can be found.
|
||||||
|
@ -336,10 +343,13 @@ class ApplicationServicesHandler:
|
||||||
)
|
)
|
||||||
if is_known_alias:
|
if is_known_alias:
|
||||||
# the alias exists now so don't query more ASes.
|
# the alias exists now so don't query more ASes.
|
||||||
result = await self.store.get_association_from_room_alias(room_alias)
|
return await self.store.get_association_from_room_alias(room_alias)
|
||||||
return result
|
|
||||||
|
|
||||||
async def query_3pe(self, kind, protocol, fields):
|
return None
|
||||||
|
|
||||||
|
async def query_3pe(
|
||||||
|
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
|
||||||
|
) -> List[JsonDict]:
|
||||||
services = self._get_services_for_3pn(protocol)
|
services = self._get_services_for_3pn(protocol)
|
||||||
|
|
||||||
results = await make_deferred_yieldable(
|
results = await make_deferred_yieldable(
|
||||||
|
@ -361,7 +371,9 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def get_3pe_protocols(self, only_protocol=None):
|
async def get_3pe_protocols(
|
||||||
|
self, only_protocol: Optional[str] = None
|
||||||
|
) -> Dict[str, JsonDict]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
protocols = {} # type: Dict[str, List[JsonDict]]
|
protocols = {} # type: Dict[str, List[JsonDict]]
|
||||||
|
|
||||||
|
@ -379,7 +391,7 @@ class ApplicationServicesHandler:
|
||||||
if info is not None:
|
if info is not None:
|
||||||
protocols[p].append(info)
|
protocols[p].append(info)
|
||||||
|
|
||||||
def _merge_instances(infos):
|
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
|
||||||
if not infos:
|
if not infos:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -394,19 +406,17 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return combined
|
return combined
|
||||||
|
|
||||||
for p in protocols.keys():
|
return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
|
||||||
protocols[p] = _merge_instances(protocols[p])
|
|
||||||
|
|
||||||
return protocols
|
async def _get_services_for_event(
|
||||||
|
self, event: EventBase
|
||||||
async def _get_services_for_event(self, event):
|
) -> List[ApplicationService]:
|
||||||
"""Retrieve a list of application services interested in this event.
|
"""Retrieve a list of application services interested in this event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event(Event): The event to check. Can be None if alias_list is not.
|
event: The event to check. Can be None if alias_list is not.
|
||||||
Returns:
|
Returns:
|
||||||
list<ApplicationService>: A list of services interested in this
|
A list of services interested in this event based on the service regex.
|
||||||
event based on the service regex.
|
|
||||||
"""
|
"""
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
|
|
||||||
|
@ -420,17 +430,15 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return interested_list
|
return interested_list
|
||||||
|
|
||||||
def _get_services_for_user(self, user_id):
|
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
|
return [s for s in services if (s.is_interested_in_user(user_id))]
|
||||||
return interested_list
|
|
||||||
|
|
||||||
def _get_services_for_3pn(self, protocol):
|
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
|
return [s for s in services if s.is_interested_in_protocol(protocol)]
|
||||||
return interested_list
|
|
||||||
|
|
||||||
async def _is_unknown_user(self, user_id):
|
async def _is_unknown_user(self, user_id: str) -> bool:
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
# we don't know if they are unknown or not since it isn't one of our
|
# we don't know if they are unknown or not since it isn't one of our
|
||||||
# users. We can't poke ASes.
|
# users. We can't poke ASes.
|
||||||
|
@ -445,9 +453,8 @@ class ApplicationServicesHandler:
|
||||||
service_list = [s for s in services if s.sender == user_id]
|
service_list = [s for s in services if s.sender == user_id]
|
||||||
return len(service_list) == 0
|
return len(service_list) == 0
|
||||||
|
|
||||||
async def _check_user_exists(self, user_id):
|
async def _check_user_exists(self, user_id: str) -> bool:
|
||||||
unknown_user = await self._is_unknown_user(user_id)
|
unknown_user = await self._is_unknown_user(user_id)
|
||||||
if unknown_user:
|
if unknown_user:
|
||||||
exists = await self.query_user_exists(user_id)
|
return await self.query_user_exists(user_id)
|
||||||
return exists
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -18,10 +18,20 @@ import logging
|
||||||
import time
|
import time
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt # type: ignore[import]
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer):
|
|
||||||
"""
|
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||||
|
@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
|
||||||
# This might return an awaitable, if it does block the log out
|
# This might return an awaitable, if it does block the log out
|
||||||
# until it completes.
|
# until it completes.
|
||||||
result = provider.on_logged_out(
|
result = provider.on_logged_out(
|
||||||
user_id=str(user_info["user"]),
|
user_id=user_info.user_id,
|
||||||
device_id=user_info["device_id"],
|
device_id=user_info.device_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
await result
|
await result
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info["token_id"] is not None:
|
if user_info.token_id is not None:
|
||||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
str(user_info["user"]), (user_info["token_id"],)
|
user_info.user_id, (user_info.token_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_access_tokens_for_user(
|
async def delete_access_tokens_for_user(
|
||||||
|
|
|
@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
|
@ -928,7 +927,7 @@ class EventCreationHandler:
|
||||||
|
|
||||||
# Ensure that we can round trip before trying to persist in db
|
# Ensure that we can round trip before trying to persist in db
|
||||||
try:
|
try:
|
||||||
dump = frozendict_json_encoder.encode(event.content)
|
dump = json_encoder.encode(event.content)
|
||||||
json_decoder.decode(dump)
|
json_decoder.decode(dump)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to encode content: %r", event.content)
|
logger.exception("Failed to encode content: %r", event.content)
|
||||||
|
@ -1100,34 +1099,13 @@ class EventCreationHandler:
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.content["membership"] == Membership.INVITE:
|
if event.content["membership"] == Membership.INVITE:
|
||||||
|
event.unsigned[
|
||||||
def is_inviter_member_event(e):
|
"invite_room_state"
|
||||||
return e.type == EventTypes.Member and e.sender == event.sender
|
] = await self.store.get_stripped_room_state_from_event_context(
|
||||||
|
context,
|
||||||
current_state_ids = await context.get_current_state_ids()
|
self.room_invite_state_types,
|
||||||
|
membership_user_id=event.sender,
|
||||||
# We know this event is not an outlier, so this must be
|
)
|
||||||
# non-None.
|
|
||||||
assert current_state_ids is not None
|
|
||||||
|
|
||||||
state_to_include_ids = [
|
|
||||||
e_id
|
|
||||||
for k, e_id in current_state_ids.items()
|
|
||||||
if k[0] in self.room_invite_state_types
|
|
||||||
or k == (EventTypes.Member, event.sender)
|
|
||||||
]
|
|
||||||
|
|
||||||
state_to_include = await self.store.get_events(state_to_include_ids)
|
|
||||||
|
|
||||||
event.unsigned["invite_room_state"] = [
|
|
||||||
{
|
|
||||||
"type": e.type,
|
|
||||||
"state_key": e.state_key,
|
|
||||||
"content": e.content,
|
|
||||||
"sender": e.sender,
|
|
||||||
}
|
|
||||||
for e in state_to_include.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
invitee = UserID.from_string(event.state_key)
|
invitee = UserID.from_string(event.state_key)
|
||||||
if not self.hs.is_mine(invitee):
|
if not self.hs.is_mine(invitee):
|
||||||
|
|
|
@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
|
||||||
|
|
||||||
MYPY = False
|
MYPY = False
|
||||||
if MYPY:
|
if MYPY:
|
||||||
import synapse.server
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
|
||||||
class BasePresenceHandler(abc.ABC):
|
class BasePresenceHandler(abc.ABC):
|
||||||
"""Parts of the PresenceHandler that are shared between workers and master"""
|
"""Parts of the PresenceHandler that are shared between workers and master"""
|
||||||
|
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
|
||||||
|
|
||||||
|
|
||||||
class PresenceHandler(BasePresenceHandler):
|
class PresenceHandler(BasePresenceHandler):
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
|
||||||
|
|
||||||
|
|
||||||
class PresenceEventSource:
|
class PresenceEventSource:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
# We can't call get_presence_handler here because there's a cycle:
|
# We can't call get_presence_handler here because there's a cycle:
|
||||||
#
|
#
|
||||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||||
|
@ -1071,12 +1071,14 @@ class PresenceEventSource:
|
||||||
|
|
||||||
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
||||||
|
|
||||||
user_ids_changed = set()
|
user_ids_changed = set() # type: Collection[str]
|
||||||
changed = None
|
changed = None
|
||||||
if from_key:
|
if from_key:
|
||||||
changed = stream_change_cache.get_all_entities_changed(from_key)
|
changed = stream_change_cache.get_all_entities_changed(from_key)
|
||||||
|
|
||||||
if changed is not None and len(changed) < 500:
|
if changed is not None and len(changed) < 500:
|
||||||
|
assert isinstance(user_ids_changed, set)
|
||||||
|
|
||||||
# For small deltas, its quicker to get all changes and then
|
# For small deltas, its quicker to get all changes and then
|
||||||
# work out if we share a room or they're in our presence list
|
# work out if we share a room or they're in our presence list
|
||||||
get_updates_counter.labels("stream").inc()
|
get_updates_counter.labels("stream").inc()
|
||||||
|
|
|
@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
)
|
)
|
||||||
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
||||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
if (
|
||||||
|
not user_data.is_guest
|
||||||
|
or UserID.from_string(user_data.user_id).localpart != localpart
|
||||||
|
):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Cannot register taken user ID without valid guest "
|
"Cannot register taken user ID without valid guest "
|
||||||
|
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
# up when the access token is saved, but that's quite an
|
# up when the access token is saved, but that's quite an
|
||||||
# invasive change I'd rather do separately.
|
# invasive change I'd rather do separately.
|
||||||
user_tuple = await self.store.get_user_by_access_token(token)
|
user_tuple = await self.store.get_user_by_access_token(token)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
|
||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
for invitee in invite_list:
|
# we avoid dropping the lock between invites, as otherwise joins can
|
||||||
|
# start coming in and making the createRoom slow.
|
||||||
|
#
|
||||||
|
# we also don't need to check the requester's shadow-ban here, as we
|
||||||
|
# have already done so above (and potentially emptied invite_list).
|
||||||
|
with (await self.room_member_handler.member_linearizer.queue((room_id,))):
|
||||||
content = {}
|
content = {}
|
||||||
is_direct = config.get("is_direct", None)
|
is_direct = config.get("is_direct", None)
|
||||||
if is_direct:
|
if is_direct:
|
||||||
content["is_direct"] = is_direct
|
content["is_direct"] = is_direct
|
||||||
|
|
||||||
# Note that update_membership with an action of "invite" can raise a
|
for invitee in invite_list:
|
||||||
# ShadowBanError, but this was handled above by emptying invite_list.
|
(
|
||||||
_, last_stream_id = await self.room_member_handler.update_membership(
|
_,
|
||||||
requester,
|
last_stream_id,
|
||||||
UserID.from_string(invitee),
|
) = await self.room_member_handler.update_membership_locked(
|
||||||
room_id,
|
requester,
|
||||||
"invite",
|
UserID.from_string(invitee),
|
||||||
ratelimit=False,
|
room_id,
|
||||||
content=content,
|
"invite",
|
||||||
)
|
ratelimit=False,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
for invite_3pid in invite_3pid_list:
|
for invite_3pid in invite_3pid_list:
|
||||||
id_server = invite_3pid["id_server"]
|
id_server = invite_3pid["id_server"]
|
||||||
|
|
|
@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# haproxy would have timed the request out anyway...
|
# haproxy would have timed the request out anyway...
|
||||||
raise SynapseError(504, "took to long to process")
|
raise SynapseError(504, "took to long to process")
|
||||||
|
|
||||||
result = await self._update_membership(
|
result = await self.update_membership_locked(
|
||||||
requester,
|
requester,
|
||||||
target,
|
target,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _update_membership(
|
async def update_membership_locked(
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
target: UserID,
|
target: UserID,
|
||||||
|
@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
content: Optional[dict] = None,
|
content: Optional[dict] = None,
|
||||||
require_consent: bool = True,
|
require_consent: bool = True,
|
||||||
) -> Tuple[str, int]:
|
) -> Tuple[str, int]:
|
||||||
|
"""Helper for update_membership.
|
||||||
|
|
||||||
|
Assumes that the membership linearizer is already held for the room.
|
||||||
|
"""
|
||||||
content_specified = bool(content)
|
content_specified = bool(content)
|
||||||
if content is None:
|
if content is None:
|
||||||
content = {}
|
content = {}
|
||||||
|
|
|
@ -359,7 +359,7 @@ class SimpleHttpClient:
|
||||||
agent=self.agent,
|
agent=self.agent,
|
||||||
data=body_producer,
|
data=body_producer,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
**self._extra_treq_args
|
**self._extra_treq_args,
|
||||||
) # type: defer.Deferred
|
) # type: defer.Deferred
|
||||||
|
|
||||||
# we use our own timeout mechanism rather than treq's as a workaround
|
# we use our own timeout mechanism rather than treq's as a workaround
|
||||||
|
|
|
@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
|
||||||
from twisted.web.static import File, NoRangeStaticProducer
|
from twisted.web.static import File, NoRangeStaticProducer
|
||||||
from twisted.web.util import redirectTo
|
from twisted.web.util import redirectTo
|
||||||
|
|
||||||
import synapse.events
|
|
||||||
import synapse.metrics
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
Codes,
|
Codes,
|
||||||
|
@ -620,7 +618,7 @@ def respond_with_json(
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
encoder = iterencode_pretty_printed_json
|
encoder = iterencode_pretty_printed_json
|
||||||
else:
|
else:
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json:
|
||||||
encoder = iterencode_canonical_json
|
encoder = iterencode_canonical_json
|
||||||
else:
|
else:
|
||||||
encoder = _encode_json_bytes
|
encoder = _encode_json_bytes
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
|
||||||
from synapse.http import redact_uri
|
from synapse.http import redact_uri
|
||||||
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
||||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||||
|
from synapse.types import Requester
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -54,9 +55,12 @@ class SynapseRequest(Request):
|
||||||
Request.__init__(self, channel, *args, **kw)
|
Request.__init__(self, channel, *args, **kw)
|
||||||
self.site = channel.site
|
self.site = channel.site
|
||||||
self._channel = channel # this is used by the tests
|
self._channel = channel # this is used by the tests
|
||||||
self.authenticated_entity = None
|
|
||||||
self.start_time = 0.0
|
self.start_time = 0.0
|
||||||
|
|
||||||
|
# The requester, if authenticated. For federation requests this is the
|
||||||
|
# server name, for client requests this is the Requester object.
|
||||||
|
self.requester = None # type: Optional[Union[Requester, str]]
|
||||||
|
|
||||||
# we can't yet create the logcontext, as we don't know the method.
|
# we can't yet create the logcontext, as we don't know the method.
|
||||||
self.logcontext = None # type: Optional[LoggingContext]
|
self.logcontext = None # type: Optional[LoggingContext]
|
||||||
|
|
||||||
|
@ -271,11 +275,23 @@ class SynapseRequest(Request):
|
||||||
# to the client (nb may be negative)
|
# to the client (nb may be negative)
|
||||||
response_send_time = self.finish_time - self._processing_finished_time
|
response_send_time = self.finish_time - self._processing_finished_time
|
||||||
|
|
||||||
# need to decode as it could be raw utf-8 bytes
|
# Convert the requester into a string that we can log
|
||||||
# from a IDN servname in an auth header
|
authenticated_entity = None
|
||||||
authenticated_entity = self.authenticated_entity
|
if isinstance(self.requester, str):
|
||||||
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
authenticated_entity = self.requester
|
||||||
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
elif isinstance(self.requester, Requester):
|
||||||
|
authenticated_entity = self.requester.authenticated_entity
|
||||||
|
|
||||||
|
# If this is a request where the target user doesn't match the user who
|
||||||
|
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
||||||
|
if self.requester.user.to_string() != authenticated_entity:
|
||||||
|
authenticated_entity = "{},{}".format(
|
||||||
|
authenticated_entity, self.requester.user.to_string(),
|
||||||
|
)
|
||||||
|
elif self.requester is not None:
|
||||||
|
# This shouldn't happen, but we log it so we don't lose information
|
||||||
|
# and can see that we're doing something wrong.
|
||||||
|
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
|
||||||
|
|
||||||
# ...or could be raw utf-8 bytes in the User-Agent header.
|
# ...or could be raw utf-8 bytes in the User-Agent header.
|
||||||
# N.B. if you don't do this, the logger explodes cryptically
|
# N.B. if you don't do this, the logger explodes cryptically
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# These are imported to allow for nicer logging configuration files.
|
||||||
|
from synapse.logging._remote import RemoteHandler
|
||||||
|
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
|
||||||
|
|
||||||
|
__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
@ -21,10 +22,11 @@ from math import floor
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from typing_extensions import Deque
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.application.internet import ClientService
|
from twisted.application.internet import ClientService
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import CancelledError, Deferred
|
||||||
from twisted.internet.endpoints import (
|
from twisted.internet.endpoints import (
|
||||||
HostnameEndpoint,
|
HostnameEndpoint,
|
||||||
TCP4ClientEndpoint,
|
TCP4ClientEndpoint,
|
||||||
|
@ -32,7 +34,9 @@ from twisted.internet.endpoints import (
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IPushProducer, ITransport
|
from twisted.internet.interfaces import IPushProducer, ITransport
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.logger import ILogObserver, Logger, LogLevel
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
|
@ -45,11 +49,11 @@ class LogProducer:
|
||||||
Args:
|
Args:
|
||||||
buffer: Log buffer to read logs from.
|
buffer: Log buffer to read logs from.
|
||||||
transport: Transport to write to.
|
transport: Transport to write to.
|
||||||
format_event: A callable to format the log entry to a string.
|
format: A callable to format the log record to a string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transport = attr.ib(type=ITransport)
|
transport = attr.ib(type=ITransport)
|
||||||
format_event = attr.ib(type=Callable[[dict], str])
|
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||||
_buffer = attr.ib(type=deque)
|
_buffer = attr.ib(type=deque)
|
||||||
_paused = attr.ib(default=False, type=bool, init=False)
|
_paused = attr.ib(default=False, type=bool, init=False)
|
||||||
|
|
||||||
|
@ -61,16 +65,19 @@ class LogProducer:
|
||||||
self._buffer = deque()
|
self._buffer = deque()
|
||||||
|
|
||||||
def resumeProducing(self):
|
def resumeProducing(self):
|
||||||
|
# If we're already producing, nothing to do.
|
||||||
self._paused = False
|
self._paused = False
|
||||||
|
|
||||||
|
# Loop until paused.
|
||||||
while self._paused is False and (self._buffer and self.transport.connected):
|
while self._paused is False and (self._buffer and self.transport.connected):
|
||||||
try:
|
try:
|
||||||
# Request the next event and format it.
|
# Request the next record and format it.
|
||||||
event = self._buffer.popleft()
|
record = self._buffer.popleft()
|
||||||
msg = self.format_event(event)
|
msg = self._format(record)
|
||||||
|
|
||||||
# Send it as a new line over the transport.
|
# Send it as a new line over the transport.
|
||||||
self.transport.write(msg.encode("utf8"))
|
self.transport.write(msg.encode("utf8"))
|
||||||
|
self.transport.write(b"\n")
|
||||||
except Exception:
|
except Exception:
|
||||||
# Something has gone wrong writing to the transport -- log it
|
# Something has gone wrong writing to the transport -- log it
|
||||||
# and break out of the while.
|
# and break out of the while.
|
||||||
|
@ -78,76 +85,85 @@ class LogProducer:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
class RemoteHandler(logging.Handler):
|
||||||
@implementer(ILogObserver)
|
|
||||||
class TCPLogObserver:
|
|
||||||
"""
|
"""
|
||||||
An IObserver that writes JSON logs to a TCP target.
|
An logging handler that writes logs to a TCP target.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (HomeServer): The homeserver that is being logged for.
|
|
||||||
host: The host of the logging target.
|
host: The host of the logging target.
|
||||||
port: The logging target's port.
|
port: The logging target's port.
|
||||||
format_event: A callable to format the log entry to a string.
|
|
||||||
maximum_buffer: The maximum buffer size.
|
maximum_buffer: The maximum buffer size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hs = attr.ib()
|
def __init__(
|
||||||
host = attr.ib(type=str)
|
self,
|
||||||
port = attr.ib(type=int)
|
host: str,
|
||||||
format_event = attr.ib(type=Callable[[dict], str])
|
port: int,
|
||||||
maximum_buffer = attr.ib(type=int)
|
maximum_buffer: int = 1000,
|
||||||
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
|
level=logging.NOTSET,
|
||||||
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
|
_reactor=None,
|
||||||
_logger = attr.ib(default=attr.Factory(Logger))
|
):
|
||||||
_producer = attr.ib(default=None, type=Optional[LogProducer])
|
super().__init__(level=level)
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.maximum_buffer = maximum_buffer
|
||||||
|
|
||||||
def start(self) -> None:
|
self._buffer = deque() # type: Deque[logging.LogRecord]
|
||||||
|
self._connection_waiter = None # type: Optional[Deferred]
|
||||||
|
self._producer = None # type: Optional[LogProducer]
|
||||||
|
|
||||||
# Connect without DNS lookups if it's a direct IP.
|
# Connect without DNS lookups if it's a direct IP.
|
||||||
|
if _reactor is None:
|
||||||
|
from twisted.internet import reactor
|
||||||
|
|
||||||
|
_reactor = reactor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ip = ip_address(self.host)
|
ip = ip_address(self.host)
|
||||||
if isinstance(ip, IPv4Address):
|
if isinstance(ip, IPv4Address):
|
||||||
endpoint = TCP4ClientEndpoint(
|
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
|
||||||
self.hs.get_reactor(), self.host, self.port
|
|
||||||
)
|
|
||||||
elif isinstance(ip, IPv6Address):
|
elif isinstance(ip, IPv6Address):
|
||||||
endpoint = TCP6ClientEndpoint(
|
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
|
||||||
self.hs.get_reactor(), self.host, self.port
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown IP address provided: %s" % (self.host,))
|
raise ValueError("Unknown IP address provided: %s" % (self.host,))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
|
endpoint = HostnameEndpoint(_reactor, self.host, self.port)
|
||||||
|
|
||||||
factory = Factory.forProtocol(Protocol)
|
factory = Factory.forProtocol(Protocol)
|
||||||
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
|
self._service = ClientService(endpoint, factory, clock=_reactor)
|
||||||
self._service.startService()
|
self._service.startService()
|
||||||
|
self._stopping = False
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
def stop(self):
|
def close(self):
|
||||||
|
self._stopping = True
|
||||||
self._service.stopService()
|
self._service.stopService()
|
||||||
|
|
||||||
def _connect(self) -> None:
|
def _connect(self) -> None:
|
||||||
"""
|
"""
|
||||||
Triggers an attempt to connect then write to the remote if not already writing.
|
Triggers an attempt to connect then write to the remote if not already writing.
|
||||||
"""
|
"""
|
||||||
|
# Do not attempt to open multiple connections.
|
||||||
if self._connection_waiter:
|
if self._connection_waiter:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
||||||
|
|
||||||
@self._connection_waiter.addErrback
|
def fail(failure: Failure) -> None:
|
||||||
def fail(r):
|
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||||
r.printTraceback(file=sys.__stderr__)
|
# reconnect (this will cause an infinite loop of errors).
|
||||||
|
if failure.check(CancelledError) and self._stopping:
|
||||||
|
return
|
||||||
|
|
||||||
|
# For a different error, print the traceback and re-connect.
|
||||||
|
failure.printTraceback(file=sys.__stderr__)
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
@self._connection_waiter.addCallback
|
def writer(result: Protocol) -> None:
|
||||||
def writer(r):
|
|
||||||
# We have a connection. If we already have a producer, and its
|
# We have a connection. If we already have a producer, and its
|
||||||
# transport is the same, just trigger a resumeProducing.
|
# transport is the same, just trigger a resumeProducing.
|
||||||
if self._producer and r.transport is self._producer.transport:
|
if self._producer and result.transport is self._producer.transport:
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
return
|
return
|
||||||
|
@ -158,29 +174,29 @@ class TCPLogObserver:
|
||||||
|
|
||||||
# Make a new producer and start it.
|
# Make a new producer and start it.
|
||||||
self._producer = LogProducer(
|
self._producer = LogProducer(
|
||||||
buffer=self._buffer,
|
buffer=self._buffer, transport=result.transport, format=self.format,
|
||||||
transport=r.transport,
|
|
||||||
format_event=self.format_event,
|
|
||||||
)
|
)
|
||||||
r.transport.registerProducer(self._producer, True)
|
result.transport.registerProducer(self._producer, True)
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
|
|
||||||
|
self._connection_waiter.addCallbacks(writer, fail)
|
||||||
|
|
||||||
def _handle_pressure(self) -> None:
|
def _handle_pressure(self) -> None:
|
||||||
"""
|
"""
|
||||||
Handle backpressure by shedding events.
|
Handle backpressure by shedding records.
|
||||||
|
|
||||||
The buffer will, in this order, until the buffer is below the maximum:
|
The buffer will, in this order, until the buffer is below the maximum:
|
||||||
- Shed DEBUG events
|
- Shed DEBUG records.
|
||||||
- Shed INFO events
|
- Shed INFO records.
|
||||||
- Shed the middle 50% of the events.
|
- Shed the middle 50% of the records.
|
||||||
"""
|
"""
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Strip out DEBUGs
|
# Strip out DEBUGs
|
||||||
self._buffer = deque(
|
self._buffer = deque(
|
||||||
filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
|
filter(lambda record: record.levelno > logging.DEBUG, self._buffer)
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
|
@ -188,7 +204,7 @@ class TCPLogObserver:
|
||||||
|
|
||||||
# Strip out INFOs
|
# Strip out INFOs
|
||||||
self._buffer = deque(
|
self._buffer = deque(
|
||||||
filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
|
filter(lambda record: record.levelno > logging.INFO, self._buffer)
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
|
@ -209,17 +225,17 @@ class TCPLogObserver:
|
||||||
|
|
||||||
self._buffer.extend(reversed(end_buffer))
|
self._buffer.extend(reversed(end_buffer))
|
||||||
|
|
||||||
def __call__(self, event: dict) -> None:
|
def emit(self, record: logging.LogRecord) -> None:
|
||||||
self._buffer.append(event)
|
self._buffer.append(record)
|
||||||
|
|
||||||
# Handle backpressure, if it exists.
|
# Handle backpressure, if it exists.
|
||||||
try:
|
try:
|
||||||
self._handle_pressure()
|
self._handle_pressure()
|
||||||
except Exception:
|
except Exception:
|
||||||
# If handling backpressure fails,clear the buffer and log the
|
# If handling backpressure fails, clear the buffer and log the
|
||||||
# exception.
|
# exception.
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
self._logger.failure("Failed clearing backpressure")
|
logger.warning("Failed clearing backpressure")
|
||||||
|
|
||||||
# Try and write immediately.
|
# Try and write immediately.
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
|
@ -12,138 +12,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
import typing
|
|
||||||
import warnings
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import attr
|
from constantly import NamedConstant, Names
|
||||||
from constantly import NamedConstant, Names, ValueConstant, Values
|
|
||||||
from zope.interface import implementer
|
|
||||||
|
|
||||||
from twisted.logger import (
|
|
||||||
FileLogObserver,
|
|
||||||
FilteringLogObserver,
|
|
||||||
ILogObserver,
|
|
||||||
LogBeginner,
|
|
||||||
Logger,
|
|
||||||
LogLevel,
|
|
||||||
LogLevelFilterPredicate,
|
|
||||||
LogPublisher,
|
|
||||||
eventAsText,
|
|
||||||
jsonFileLogObserver,
|
|
||||||
)
|
|
||||||
|
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
from synapse.logging._terse_json import (
|
|
||||||
TerseJSONToConsoleLogObserver,
|
|
||||||
TerseJSONToTCPLogObserver,
|
|
||||||
)
|
|
||||||
from synapse.logging.context import current_context
|
|
||||||
|
|
||||||
|
|
||||||
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
|
|
||||||
"""
|
|
||||||
Convert a stdlib log level to Twisted's log level.
|
|
||||||
"""
|
|
||||||
lvl = level.lower().replace("warning", "warn")
|
|
||||||
return LogLevel.levelWithName(lvl)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
@implementer(ILogObserver)
|
|
||||||
class LogContextObserver:
|
|
||||||
"""
|
|
||||||
An ILogObserver which adds Synapse-specific log context information.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
observer (ILogObserver): The target parent observer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
observer = attr.ib()
|
|
||||||
|
|
||||||
def __call__(self, event: dict) -> None:
|
|
||||||
"""
|
|
||||||
Consume a log event and emit it to the parent observer after filtering
|
|
||||||
and adding log context information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event (dict)
|
|
||||||
"""
|
|
||||||
# Filter out some useless events that Twisted outputs
|
|
||||||
if "log_text" in event:
|
|
||||||
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
|
|
||||||
return
|
|
||||||
|
|
||||||
if event["log_text"].startswith("(UDP Port "):
|
|
||||||
return
|
|
||||||
|
|
||||||
if event["log_text"].startswith("Timing out client") or event[
|
|
||||||
"log_format"
|
|
||||||
].startswith("Timing out client"):
|
|
||||||
return
|
|
||||||
|
|
||||||
context = current_context()
|
|
||||||
|
|
||||||
# Copy the context information to the log event.
|
|
||||||
context.copy_to_twisted_log_entry(event)
|
|
||||||
|
|
||||||
self.observer(event)
|
|
||||||
|
|
||||||
|
|
||||||
class PythonStdlibToTwistedLogger(logging.Handler):
|
|
||||||
"""
|
|
||||||
Transform a Python stdlib log message into a Twisted one.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, observer, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
observer (ILogObserver): A Twisted logging observer.
|
|
||||||
*args, **kwargs: Args/kwargs to be passed to logging.Handler.
|
|
||||||
"""
|
|
||||||
self.observer = observer
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord) -> None:
|
|
||||||
"""
|
|
||||||
Emit a record to Twisted's observer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record (logging.LogRecord)
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.observer(
|
|
||||||
{
|
|
||||||
"log_time": record.created,
|
|
||||||
"log_text": record.getMessage(),
|
|
||||||
"log_format": "{log_text}",
|
|
||||||
"log_namespace": record.name,
|
|
||||||
"log_level": stdlib_log_level_to_twisted(record.levelname),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
|
|
||||||
"""
|
|
||||||
A log observer that formats events like the traditional log formatter and
|
|
||||||
sends them to `outFile`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
outFile (file object): The file object to write to.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def formatEvent(_event: dict) -> str:
|
|
||||||
event = dict(_event)
|
|
||||||
event["log_level"] = event["log_level"].name.upper()
|
|
||||||
event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
|
|
||||||
event.get("log_format", "{log_text}") or "{log_text}"
|
|
||||||
)
|
|
||||||
return eventAsText(event, includeSystem=False) + "\n"
|
|
||||||
|
|
||||||
return FileLogObserver(outFile, formatEvent)
|
|
||||||
|
|
||||||
|
|
||||||
class DrainType(Names):
|
class DrainType(Names):
|
||||||
|
@ -155,30 +29,12 @@ class DrainType(Names):
|
||||||
NETWORK_JSON_TERSE = NamedConstant()
|
NETWORK_JSON_TERSE = NamedConstant()
|
||||||
|
|
||||||
|
|
||||||
class OutputPipeType(Values):
|
DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
|
||||||
stdout = ValueConstant(sys.__stdout__)
|
|
||||||
stderr = ValueConstant(sys.__stderr__)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
class DrainConfiguration:
|
|
||||||
name = attr.ib()
|
|
||||||
type = attr.ib()
|
|
||||||
location = attr.ib()
|
|
||||||
options = attr.ib(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
class NetworkJSONTerseOptions:
|
|
||||||
maximum_buffer = attr.ib(type=int)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_drain_configs(
|
def parse_drain_configs(
|
||||||
drains: dict,
|
drains: dict,
|
||||||
) -> typing.Generator[DrainConfiguration, None, None]:
|
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
"""
|
"""
|
||||||
Parse the drain configurations.
|
Parse the drain configurations.
|
||||||
|
|
||||||
|
@ -186,11 +42,12 @@ def parse_drain_configs(
|
||||||
drains (dict): A list of drain configurations.
|
drains (dict): A list of drain configurations.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
DrainConfiguration instances.
|
dict instances representing a logging handler.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ConfigError: If any of the drain configuration items are invalid.
|
ConfigError: If any of the drain configuration items are invalid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for name, config in drains.items():
|
for name, config in drains.items():
|
||||||
if "type" not in config:
|
if "type" not in config:
|
||||||
raise ConfigError("Logging drains require a 'type' key.")
|
raise ConfigError("Logging drains require a 'type' key.")
|
||||||
|
@ -202,6 +59,18 @@ def parse_drain_configs(
|
||||||
"%s is not a known logging drain type." % (config["type"],)
|
"%s is not a known logging drain type." % (config["type"],)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Either use the default formatter or the tersejson one.
|
||||||
|
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
|
||||||
|
formatter = "json" # type: Optional[str]
|
||||||
|
elif logging_type in (
|
||||||
|
DrainType.CONSOLE_JSON_TERSE,
|
||||||
|
DrainType.NETWORK_JSON_TERSE,
|
||||||
|
):
|
||||||
|
formatter = "tersejson"
|
||||||
|
else:
|
||||||
|
# A formatter of None implies using the default formatter.
|
||||||
|
formatter = None
|
||||||
|
|
||||||
if logging_type in [
|
if logging_type in [
|
||||||
DrainType.CONSOLE,
|
DrainType.CONSOLE,
|
||||||
DrainType.CONSOLE_JSON,
|
DrainType.CONSOLE_JSON,
|
||||||
|
@ -217,9 +86,11 @@ def parse_drain_configs(
|
||||||
% (logging_type,)
|
% (logging_type,)
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = OutputPipeType.lookupByName(location).value
|
yield name, {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
yield DrainConfiguration(name=name, type=logging_type, location=pipe)
|
"formatter": formatter,
|
||||||
|
"stream": "ext://sys." + location,
|
||||||
|
}
|
||||||
|
|
||||||
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
|
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
|
||||||
if "location" not in config:
|
if "location" not in config:
|
||||||
|
@ -233,18 +104,25 @@ def parse_drain_configs(
|
||||||
"File paths need to be absolute, '%s' is a relative path"
|
"File paths need to be absolute, '%s' is a relative path"
|
||||||
% (location,)
|
% (location,)
|
||||||
)
|
)
|
||||||
yield DrainConfiguration(name=name, type=logging_type, location=location)
|
|
||||||
|
yield name, {
|
||||||
|
"class": "logging.FileHandler",
|
||||||
|
"formatter": formatter,
|
||||||
|
"filename": location,
|
||||||
|
}
|
||||||
|
|
||||||
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
|
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
|
||||||
host = config.get("host")
|
host = config.get("host")
|
||||||
port = config.get("port")
|
port = config.get("port")
|
||||||
maximum_buffer = config.get("maximum_buffer", 1000)
|
maximum_buffer = config.get("maximum_buffer", 1000)
|
||||||
yield DrainConfiguration(
|
|
||||||
name=name,
|
yield name, {
|
||||||
type=logging_type,
|
"class": "synapse.logging.RemoteHandler",
|
||||||
location=(host, port),
|
"formatter": formatter,
|
||||||
options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
|
"host": host,
|
||||||
)
|
"port": port,
|
||||||
|
"maximum_buffer": maximum_buffer,
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
|
@ -253,126 +131,29 @@ def parse_drain_configs(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StoppableLogPublisher(LogPublisher):
|
def setup_structured_logging(log_config: dict,) -> dict:
|
||||||
"""
|
"""
|
||||||
A log publisher that can tell its observers to shut down any external
|
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
|
||||||
communications.
|
to one compatible with the new standard library handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
for obs in self._observers:
|
|
||||||
if hasattr(obs, "stop"):
|
|
||||||
obs.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_structured_logging(
|
|
||||||
hs,
|
|
||||||
config,
|
|
||||||
log_config: dict,
|
|
||||||
logBeginner: LogBeginner,
|
|
||||||
redirect_stdlib_logging: bool = True,
|
|
||||||
) -> LogPublisher:
|
|
||||||
"""
|
|
||||||
Set up Twisted's structured logging system.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hs: The homeserver to use.
|
|
||||||
config (HomeserverConfig): The configuration of the Synapse homeserver.
|
|
||||||
log_config (dict): The log configuration to use.
|
|
||||||
"""
|
|
||||||
if config.no_redirect_stdio:
|
|
||||||
raise ConfigError(
|
|
||||||
"no_redirect_stdio cannot be defined using structured logging."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = Logger()
|
|
||||||
|
|
||||||
if "drains" not in log_config:
|
if "drains" not in log_config:
|
||||||
raise ConfigError("The logging configuration requires a list of drains.")
|
raise ConfigError("The logging configuration requires a list of drains.")
|
||||||
|
|
||||||
observers = [] # type: List[ILogObserver]
|
new_config = {
|
||||||
|
"version": 1,
|
||||||
|
"formatters": {
|
||||||
|
"json": {"class": "synapse.logging.JsonFormatter"},
|
||||||
|
"tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
|
||||||
|
},
|
||||||
|
"handlers": {},
|
||||||
|
"loggers": log_config.get("loggers", DEFAULT_LOGGERS),
|
||||||
|
"root": {"handlers": []},
|
||||||
|
}
|
||||||
|
|
||||||
for observer in parse_drain_configs(log_config["drains"]):
|
for handler_name, handler in parse_drain_configs(log_config["drains"]):
|
||||||
# Pipe drains
|
new_config["handlers"][handler_name] = handler
|
||||||
if observer.type == DrainType.CONSOLE:
|
|
||||||
logger.debug(
|
|
||||||
"Starting up the {name} console logger drain", name=observer.name
|
|
||||||
)
|
|
||||||
observers.append(SynapseFileLogObserver(observer.location))
|
|
||||||
elif observer.type == DrainType.CONSOLE_JSON:
|
|
||||||
logger.debug(
|
|
||||||
"Starting up the {name} JSON console logger drain", name=observer.name
|
|
||||||
)
|
|
||||||
observers.append(jsonFileLogObserver(observer.location))
|
|
||||||
elif observer.type == DrainType.CONSOLE_JSON_TERSE:
|
|
||||||
logger.debug(
|
|
||||||
"Starting up the {name} terse JSON console logger drain",
|
|
||||||
name=observer.name,
|
|
||||||
)
|
|
||||||
observers.append(
|
|
||||||
TerseJSONToConsoleLogObserver(observer.location, metadata={})
|
|
||||||
)
|
|
||||||
|
|
||||||
# File drains
|
# Add each handler to the root logger.
|
||||||
elif observer.type == DrainType.FILE:
|
new_config["root"]["handlers"].append(handler_name)
|
||||||
logger.debug("Starting up the {name} file logger drain", name=observer.name)
|
|
||||||
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
|
|
||||||
observers.append(SynapseFileLogObserver(log_file))
|
|
||||||
elif observer.type == DrainType.FILE_JSON:
|
|
||||||
logger.debug(
|
|
||||||
"Starting up the {name} JSON file logger drain", name=observer.name
|
|
||||||
)
|
|
||||||
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
|
|
||||||
observers.append(jsonFileLogObserver(log_file))
|
|
||||||
|
|
||||||
elif observer.type == DrainType.NETWORK_JSON_TERSE:
|
return new_config
|
||||||
metadata = {"server_name": hs.config.server_name}
|
|
||||||
log_observer = TerseJSONToTCPLogObserver(
|
|
||||||
hs=hs,
|
|
||||||
host=observer.location[0],
|
|
||||||
port=observer.location[1],
|
|
||||||
metadata=metadata,
|
|
||||||
maximum_buffer=observer.options.maximum_buffer,
|
|
||||||
)
|
|
||||||
log_observer.start()
|
|
||||||
observers.append(log_observer)
|
|
||||||
else:
|
|
||||||
# We should never get here, but, just in case, throw an error.
|
|
||||||
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
|
|
||||||
|
|
||||||
publisher = StoppableLogPublisher(*observers)
|
|
||||||
log_filter = LogLevelFilterPredicate()
|
|
||||||
|
|
||||||
for namespace, namespace_config in log_config.get(
|
|
||||||
"loggers", DEFAULT_LOGGERS
|
|
||||||
).items():
|
|
||||||
# Set the log level for twisted.logger.Logger namespaces
|
|
||||||
log_filter.setLogLevelForNamespace(
|
|
||||||
namespace,
|
|
||||||
stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also set the log levels for the stdlib logger namespaces, to prevent
|
|
||||||
# them getting to PythonStdlibToTwistedLogger and having to be formatted
|
|
||||||
if "level" in namespace_config:
|
|
||||||
logging.getLogger(namespace).setLevel(namespace_config.get("level"))
|
|
||||||
|
|
||||||
f = FilteringLogObserver(publisher, [log_filter])
|
|
||||||
lco = LogContextObserver(f)
|
|
||||||
|
|
||||||
if redirect_stdlib_logging:
|
|
||||||
stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
|
|
||||||
stdliblogger = logging.getLogger()
|
|
||||||
stdliblogger.addHandler(stuff_into_twisted)
|
|
||||||
|
|
||||||
# Always redirect standard I/O, otherwise other logging outputs might miss
|
|
||||||
# it.
|
|
||||||
logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
|
|
||||||
|
|
||||||
return publisher
|
|
||||||
|
|
||||||
|
|
||||||
def reload_structured_logging(*args, log_config=None) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"Currently the structured logging system can not be reloaded, doing nothing"
|
|
||||||
)
|
|
||||||
|
|
|
@ -16,141 +16,65 @@
|
||||||
"""
|
"""
|
||||||
Log formatters that output terse JSON.
|
Log formatters that output terse JSON.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import IO
|
import logging
|
||||||
|
|
||||||
from twisted.logger import FileLogObserver
|
|
||||||
|
|
||||||
from synapse.logging._remote import TCPLogObserver
|
|
||||||
|
|
||||||
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
|
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|
||||||
|
# The properties of a standard LogRecord.
|
||||||
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
|
_LOG_RECORD_ATTRIBUTES = {
|
||||||
"""
|
"args",
|
||||||
Flatten a Twisted logging event to an dictionary capable of being sent
|
"asctime",
|
||||||
as a log event to a logging aggregation system.
|
"created",
|
||||||
|
"exc_info",
|
||||||
The format is vastly simplified and is not designed to be a "human readable
|
# exc_text isn't a public attribute, but is used to cache the result of formatException.
|
||||||
string" in the sense that traditional logs are. Instead, the structure is
|
"exc_text",
|
||||||
optimised for searchability and filtering, with human-understandable log
|
"filename",
|
||||||
keys.
|
"funcName",
|
||||||
|
"levelname",
|
||||||
Args:
|
"levelno",
|
||||||
event (dict): The Twisted logging event we are flattening.
|
"lineno",
|
||||||
metadata (dict): Additional data to include with each log message. This
|
"message",
|
||||||
can be information like the server name. Since the target log
|
"module",
|
||||||
consumer does not know who we are other than by host IP, this
|
"msecs",
|
||||||
allows us to forward through static information.
|
"msg",
|
||||||
include_time (bool): Should we include the `time` key? If False, the
|
"name",
|
||||||
event time is stripped from the event.
|
"pathname",
|
||||||
"""
|
"process",
|
||||||
new_event = {}
|
"processName",
|
||||||
|
"relativeCreated",
|
||||||
# If it's a failure, make the new event's log_failure be the traceback text.
|
"stack_info",
|
||||||
if "log_failure" in event:
|
"thread",
|
||||||
new_event["log_failure"] = event["log_failure"].getTraceback()
|
"threadName",
|
||||||
|
}
|
||||||
# If it's a warning, copy over a string representation of the warning.
|
|
||||||
if "warning" in event:
|
|
||||||
new_event["warning"] = str(event["warning"])
|
|
||||||
|
|
||||||
# Stdlib logging events have "log_text" as their human-readable portion,
|
|
||||||
# Twisted ones have "log_format". For now, include the log_format, so that
|
|
||||||
# context only given in the log format (e.g. what is being logged) is
|
|
||||||
# available.
|
|
||||||
if "log_text" in event:
|
|
||||||
new_event["log"] = event["log_text"]
|
|
||||||
else:
|
|
||||||
new_event["log"] = event["log_format"]
|
|
||||||
|
|
||||||
# We want to include the timestamp when forwarding over the network, but
|
|
||||||
# exclude it when we are writing to stdout. This is because the log ingester
|
|
||||||
# (e.g. logstash, fluentd) can add its own timestamp.
|
|
||||||
if include_time:
|
|
||||||
new_event["time"] = round(event["log_time"], 2)
|
|
||||||
|
|
||||||
# Convert the log level to a textual representation.
|
|
||||||
new_event["level"] = event["log_level"].name.upper()
|
|
||||||
|
|
||||||
# Ignore these keys, and do not transfer them over to the new log object.
|
|
||||||
# They are either useless (isError), transferred manually above (log_time,
|
|
||||||
# log_level, etc), or contain Python objects which are not useful for output
|
|
||||||
# (log_logger, log_source).
|
|
||||||
keys_to_delete = [
|
|
||||||
"isError",
|
|
||||||
"log_failure",
|
|
||||||
"log_format",
|
|
||||||
"log_level",
|
|
||||||
"log_logger",
|
|
||||||
"log_source",
|
|
||||||
"log_system",
|
|
||||||
"log_time",
|
|
||||||
"log_text",
|
|
||||||
"observer",
|
|
||||||
"warning",
|
|
||||||
]
|
|
||||||
|
|
||||||
# If it's from the Twisted legacy logger (twisted.python.log), it adds some
|
|
||||||
# more keys we want to purge.
|
|
||||||
if event.get("log_namespace") == "log_legacy":
|
|
||||||
keys_to_delete.extend(["message", "system", "time"])
|
|
||||||
|
|
||||||
# Rather than modify the dictionary in place, construct a new one with only
|
|
||||||
# the content we want. The original event should be considered 'frozen'.
|
|
||||||
for key in event.keys():
|
|
||||||
|
|
||||||
if key in keys_to_delete:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
|
|
||||||
# If it's a plain type, include it as is.
|
|
||||||
new_event[key] = event[key]
|
|
||||||
else:
|
|
||||||
# If it's not one of those basic types, write out a string
|
|
||||||
# representation. This should probably be a warning in development,
|
|
||||||
# so that we are sure we are only outputting useful data.
|
|
||||||
new_event[key] = str(event[key])
|
|
||||||
|
|
||||||
# Add the metadata information to the event (e.g. the server_name).
|
|
||||||
new_event.update(metadata)
|
|
||||||
|
|
||||||
return new_event
|
|
||||||
|
|
||||||
|
|
||||||
def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
|
class JsonFormatter(logging.Formatter):
|
||||||
"""
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
A log observer that formats events to a flattened JSON representation.
|
event = {
|
||||||
|
"log": record.getMessage(),
|
||||||
|
"namespace": record.name,
|
||||||
|
"level": record.levelname,
|
||||||
|
}
|
||||||
|
|
||||||
Args:
|
return self._format(record, event)
|
||||||
outFile: The file object to write to.
|
|
||||||
metadata: Metadata to be added to each log object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def formatEvent(_event: dict) -> str:
|
def _format(self, record: logging.LogRecord, event: dict) -> str:
|
||||||
flattened = flatten_event(_event, metadata)
|
# Add any extra attributes to the event.
|
||||||
return _encoder.encode(flattened) + "\n"
|
for key, value in record.__dict__.items():
|
||||||
|
if key not in _LOG_RECORD_ATTRIBUTES:
|
||||||
|
event[key] = value
|
||||||
|
|
||||||
return FileLogObserver(outFile, formatEvent)
|
return _encoder.encode(event)
|
||||||
|
|
||||||
|
|
||||||
def TerseJSONToTCPLogObserver(
|
class TerseJsonFormatter(JsonFormatter):
|
||||||
hs, host: str, port: int, metadata: dict, maximum_buffer: int
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
) -> FileLogObserver:
|
event = {
|
||||||
"""
|
"log": record.getMessage(),
|
||||||
A log observer that formats events to a flattened JSON representation.
|
"namespace": record.name,
|
||||||
|
"level": record.levelname,
|
||||||
|
"time": round(record.created, 2),
|
||||||
|
}
|
||||||
|
|
||||||
Args:
|
return self._format(record, event)
|
||||||
hs (HomeServer): The homeserver that is being logged for.
|
|
||||||
host: The host of the logging target.
|
|
||||||
port: The logging target's port.
|
|
||||||
metadata: Metadata to be added to each log object.
|
|
||||||
maximum_buffer: The maximum buffer size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def formatEvent(_event: dict) -> str:
|
|
||||||
flattened = flatten_event(_event, metadata, include_time=True)
|
|
||||||
return _encoder.encode(flattened) + "\n"
|
|
||||||
|
|
||||||
return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
|
|
||||||
|
|
33
synapse/logging/filter.py
Normal file
33
synapse/logging/filter.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilter(logging.Filter):
|
||||||
|
"""Logging filter that adds constant values to each record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Key-value pairs to add to each record.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, metadata: dict):
|
||||||
|
self._metadata = metadata
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> Literal[True]:
|
||||||
|
for key, value in self._metadata.items():
|
||||||
|
setattr(record, key, value)
|
||||||
|
return True
|
|
@ -28,6 +28,7 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
|
||||||
return bool(self.events)
|
return bool(self.events)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class _PendingRoomEventEntry:
|
||||||
|
event_pos = attr.ib(type=PersistedEventPosition)
|
||||||
|
extra_users = attr.ib(type=Collection[UserID])
|
||||||
|
|
||||||
|
room_id = attr.ib(type=str)
|
||||||
|
type = attr.ib(type=str)
|
||||||
|
state_key = attr.ib(type=Optional[str])
|
||||||
|
membership = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
|
|
||||||
class Notifier:
|
class Notifier:
|
||||||
""" This class is responsible for notifying any listeners when there are
|
""" This class is responsible for notifying any listeners when there are
|
||||||
new events available for it.
|
new events available for it.
|
||||||
|
@ -190,9 +202,7 @@ class Notifier:
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.pending_new_room_events = (
|
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
|
||||||
[]
|
|
||||||
) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
|
|
||||||
|
|
||||||
# Called when there are new things to stream over replication
|
# Called when there are new things to stream over replication
|
||||||
self.replication_callbacks = [] # type: List[Callable[[], None]]
|
self.replication_callbacks = [] # type: List[Callable[[], None]]
|
||||||
|
@ -255,7 +265,29 @@ class Notifier:
|
||||||
max_room_stream_token: RoomStreamToken,
|
max_room_stream_token: RoomStreamToken,
|
||||||
extra_users: Collection[UserID] = [],
|
extra_users: Collection[UserID] = [],
|
||||||
):
|
):
|
||||||
""" Used by handlers to inform the notifier something has happened
|
"""Unwraps event and calls `on_new_room_event_args`.
|
||||||
|
"""
|
||||||
|
self.on_new_room_event_args(
|
||||||
|
event_pos=event_pos,
|
||||||
|
room_id=event.room_id,
|
||||||
|
event_type=event.type,
|
||||||
|
state_key=event.get("state_key"),
|
||||||
|
membership=event.content.get("membership"),
|
||||||
|
max_room_stream_token=max_room_stream_token,
|
||||||
|
extra_users=extra_users,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_new_room_event_args(
|
||||||
|
self,
|
||||||
|
room_id: str,
|
||||||
|
event_type: str,
|
||||||
|
state_key: Optional[str],
|
||||||
|
membership: Optional[str],
|
||||||
|
event_pos: PersistedEventPosition,
|
||||||
|
max_room_stream_token: RoomStreamToken,
|
||||||
|
extra_users: Collection[UserID] = [],
|
||||||
|
):
|
||||||
|
"""Used by handlers to inform the notifier something has happened
|
||||||
in the room, room event wise.
|
in the room, room event wise.
|
||||||
|
|
||||||
This triggers the notifier to wake up any listeners that are
|
This triggers the notifier to wake up any listeners that are
|
||||||
|
@ -266,7 +298,16 @@ class Notifier:
|
||||||
until all previous events have been persisted before notifying
|
until all previous events have been persisted before notifying
|
||||||
the client streams.
|
the client streams.
|
||||||
"""
|
"""
|
||||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
self.pending_new_room_events.append(
|
||||||
|
_PendingRoomEventEntry(
|
||||||
|
event_pos=event_pos,
|
||||||
|
extra_users=extra_users,
|
||||||
|
room_id=room_id,
|
||||||
|
type=event_type,
|
||||||
|
state_key=state_key,
|
||||||
|
membership=membership,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._notify_pending_new_room_events(max_room_stream_token)
|
self._notify_pending_new_room_events(max_room_stream_token)
|
||||||
|
|
||||||
self.notify_replication()
|
self.notify_replication()
|
||||||
|
@ -284,18 +325,19 @@ class Notifier:
|
||||||
users = set() # type: Set[UserID]
|
users = set() # type: Set[UserID]
|
||||||
rooms = set() # type: Set[str]
|
rooms = set() # type: Set[str]
|
||||||
|
|
||||||
for event_pos, event, extra_users in pending:
|
for entry in pending:
|
||||||
if event_pos.persisted_after(max_room_stream_token):
|
if entry.event_pos.persisted_after(max_room_stream_token):
|
||||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
self.pending_new_room_events.append(entry)
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
event.type == EventTypes.Member
|
entry.type == EventTypes.Member
|
||||||
and event.membership == Membership.JOIN
|
and entry.membership == Membership.JOIN
|
||||||
|
and entry.state_key
|
||||||
):
|
):
|
||||||
self._user_joined_room(event.state_key, event.room_id)
|
self._user_joined_room(entry.state_key, entry.room_id)
|
||||||
|
|
||||||
users.update(extra_users)
|
users.update(entry.extra_users)
|
||||||
rooms.add(event.room_id)
|
rooms.add(entry.room_id)
|
||||||
|
|
||||||
if users or rooms:
|
if users or rooms:
|
||||||
self.on_new_event(
|
self.on_new_event(
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||||
|
@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import lru_cache
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
|
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
|
||||||
dict of user_id -> push_rules
|
dict of user_id -> push_rules
|
||||||
"""
|
"""
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
rules_for_room = await self._get_rules_for_room(room_id)
|
rules_for_room = self._get_rules_for_room(room_id)
|
||||||
|
|
||||||
rules_by_user = await rules_for_room.get_rules(event, context)
|
rules_by_user = await rules_for_room.get_rules(event, context)
|
||||||
|
|
||||||
|
@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
|
||||||
|
|
||||||
return rules_by_user
|
return rules_by_user
|
||||||
|
|
||||||
@cached()
|
@lru_cache()
|
||||||
def _get_rules_for_room(self, room_id):
|
def _get_rules_for_room(self, room_id):
|
||||||
"""Get the current RulesForRoom object for the given room id
|
"""Get the current RulesForRoom object for the given room id
|
||||||
|
|
||||||
|
@ -275,12 +276,14 @@ class RulesForRoom:
|
||||||
the entire cache for the room.
|
the entire cache for the room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
|
def __init__(
|
||||||
|
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hs (HomeServer)
|
hs (HomeServer)
|
||||||
room_id (str)
|
room_id (str)
|
||||||
rules_for_room_cache(Cache): The cache object that caches these
|
rules_for_room_cache: The cache object that caches these
|
||||||
RoomsForUser objects.
|
RoomsForUser objects.
|
||||||
room_push_rule_cache_metrics (CacheMetric)
|
room_push_rule_cache_metrics (CacheMetric)
|
||||||
"""
|
"""
|
||||||
|
@ -489,13 +492,21 @@ class RulesForRoom:
|
||||||
self.state_group = state_group
|
self.state_group = state_group
|
||||||
|
|
||||||
|
|
||||||
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
|
@attr.attrs(slots=True, frozen=True)
|
||||||
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
|
class _Invalidation:
|
||||||
# which namedtuple does for us (i.e. two _CacheContext are the same if
|
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
|
||||||
# their caches and keys match). This is important in particular to
|
# which means that it it is stored on the bulk_get_push_rules cache entry. In order
|
||||||
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
# to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
|
||||||
# of callbacks would grow.
|
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
|
||||||
|
# same `cache` and `room_id`.
|
||||||
|
#
|
||||||
|
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
|
||||||
|
# set `frozen=True`.
|
||||||
|
|
||||||
|
cache = attr.ib(type=LruCache)
|
||||||
|
room_id = attr.ib(type=str)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
|
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||||
if rules:
|
if rules:
|
||||||
rules.invalidate_all()
|
rules.invalidate_all()
|
||||||
|
|
|
@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
requester = Requester.deserialize(self.store, content["requester"])
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
|
|
||||||
if requester.user:
|
request.requester = requester
|
||||||
request.authenticated_entity = requester.user.to_string()
|
|
||||||
|
|
||||||
logger.info("remote_join: %s into room: %s", user_id, room_id)
|
logger.info("remote_join: %s into room: %s", user_id, room_id)
|
||||||
|
|
||||||
|
@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
requester = Requester.deserialize(self.store, content["requester"])
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
|
|
||||||
if requester.user:
|
request.requester = requester
|
||||||
request.authenticated_entity = requester.user.to_string()
|
|
||||||
|
|
||||||
# hopefully we're now on the master, so this won't recurse!
|
# hopefully we're now on the master, so this won't recurse!
|
||||||
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
||||||
|
|
|
@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
ratelimit = content["ratelimit"]
|
ratelimit = content["ratelimit"]
|
||||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||||
|
|
||||||
if requester.user:
|
request.requester = requester
|
||||||
request.authenticated_entity = requester.user.to_string()
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
||||||
|
|
|
@ -141,21 +141,25 @@ class ReplicationDataHandler:
|
||||||
if row.type != EventsStreamEventRow.TypeId:
|
if row.type != EventsStreamEventRow.TypeId:
|
||||||
continue
|
continue
|
||||||
assert isinstance(row, EventsStreamRow)
|
assert isinstance(row, EventsStreamRow)
|
||||||
|
assert isinstance(row.data, EventsStreamEventRow)
|
||||||
|
|
||||||
event = await self.store.get_event(
|
if row.data.rejected:
|
||||||
row.data.event_id, allow_rejected=True
|
|
||||||
)
|
|
||||||
if event.rejected_reason:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
extra_users = () # type: Tuple[UserID, ...]
|
extra_users = () # type: Tuple[UserID, ...]
|
||||||
if event.type == EventTypes.Member:
|
if row.data.type == EventTypes.Member and row.data.state_key:
|
||||||
extra_users = (UserID.from_string(event.state_key),)
|
extra_users = (UserID.from_string(row.data.state_key),)
|
||||||
|
|
||||||
max_token = self.store.get_room_max_token()
|
max_token = self.store.get_room_max_token()
|
||||||
event_pos = PersistedEventPosition(instance_name, token)
|
event_pos = PersistedEventPosition(instance_name, token)
|
||||||
self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event_args(
|
||||||
event, event_pos, max_token, extra_users
|
event_pos=event_pos,
|
||||||
|
max_room_stream_token=max_token,
|
||||||
|
extra_users=extra_users,
|
||||||
|
room_id=row.data.room_id,
|
||||||
|
event_type=row.data.type,
|
||||||
|
state_key=row.data.state_key,
|
||||||
|
membership=row.data.membership,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Notify any waiting deferreds. The list is ordered by position so we
|
# Notify any waiting deferreds. The list is ordered by position so we
|
||||||
|
|
|
@ -15,12 +15,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import heapq
|
import heapq
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import List, Tuple, Type
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from ._base import Stream, StreamUpdateResult, Token
|
from ._base import Stream, StreamUpdateResult, Token
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
"""Handling of the 'events' replication stream
|
"""Handling of the 'events' replication stream
|
||||||
|
|
||||||
This stream contains rows of various types. Each row therefore contains a 'type'
|
This stream contains rows of various types. Each row therefore contains a 'type'
|
||||||
|
@ -81,12 +84,14 @@ class BaseEventsStreamRow:
|
||||||
class EventsStreamEventRow(BaseEventsStreamRow):
|
class EventsStreamEventRow(BaseEventsStreamRow):
|
||||||
TypeId = "ev"
|
TypeId = "ev"
|
||||||
|
|
||||||
event_id = attr.ib() # str
|
event_id = attr.ib(type=str)
|
||||||
room_id = attr.ib() # str
|
room_id = attr.ib(type=str)
|
||||||
type = attr.ib() # str
|
type = attr.ib(type=str)
|
||||||
state_key = attr.ib() # str, optional
|
state_key = attr.ib(type=Optional[str])
|
||||||
redacts = attr.ib() # str, optional
|
redacts = attr.ib(type=Optional[str])
|
||||||
relates_to = attr.ib() # str, optional
|
relates_to = attr.ib(type=Optional[str])
|
||||||
|
membership = attr.ib(type=Optional[str])
|
||||||
|
rejected = attr.ib(type=bool)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
|
@ -113,7 +118,7 @@ class EventsStream(Stream):
|
||||||
|
|
||||||
NAME = "events"
|
NAME = "events"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
|
|
|
@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
|
||||||
from synapse.rest.admin.users import (
|
from synapse.rest.admin.users import (
|
||||||
AccountValidityRenewServlet,
|
AccountValidityRenewServlet,
|
||||||
DeactivateAccountRestServlet,
|
DeactivateAccountRestServlet,
|
||||||
|
PushersRestServlet,
|
||||||
ResetPasswordRestServlet,
|
ResetPasswordRestServlet,
|
||||||
SearchUsersRestServlet,
|
SearchUsersRestServlet,
|
||||||
UserAdminServlet,
|
UserAdminServlet,
|
||||||
|
@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
|
||||||
DeviceRestServlet(hs).register(http_server)
|
DeviceRestServlet(hs).register(http_server)
|
||||||
DevicesRestServlet(hs).register(http_server)
|
DevicesRestServlet(hs).register(http_server)
|
||||||
DeleteDevicesRestServlet(hs).register(http_server)
|
DeleteDevicesRestServlet(hs).register(http_server)
|
||||||
EventReportsRestServlet(hs).register(http_server)
|
|
||||||
EventReportDetailRestServlet(hs).register(http_server)
|
EventReportDetailRestServlet(hs).register(http_server)
|
||||||
|
EventReportsRestServlet(hs).register(http_server)
|
||||||
|
PushersRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets_for_client_rest_resource(hs, http_server):
|
def register_servlets_for_client_rest_resource(hs, http_server):
|
||||||
|
|
|
@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GET_PUSHERS_ALLOWED_KEYS = {
|
||||||
|
"app_display_name",
|
||||||
|
"app_id",
|
||||||
|
"data",
|
||||||
|
"device_display_name",
|
||||||
|
"kind",
|
||||||
|
"lang",
|
||||||
|
"profile_tag",
|
||||||
|
"pushkey",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class UsersRestServlet(RestServlet):
|
class UsersRestServlet(RestServlet):
|
||||||
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
|
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
|
||||||
|
@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
|
||||||
|
|
||||||
|
class PushersRestServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
Gets information about all pushers for a specific `user_id`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
http://localhost:8008/_synapse/admin/v1/users/
|
||||||
|
@user:server/pushers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pushers: Dictionary containing pushers information.
|
||||||
|
total: Number of pushers in dictonary `pushers`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
if not self.is_mine(UserID.from_string(user_id)):
|
||||||
|
raise SynapseError(400, "Can only lookup local users")
|
||||||
|
|
||||||
|
if not await self.store.get_user_by_id(user_id):
|
||||||
|
raise NotFoundError("User not found")
|
||||||
|
|
||||||
|
pushers = await self.store.get_pushers_by_user_id(user_id)
|
||||||
|
|
||||||
|
filtered_pushers = [
|
||||||
|
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
|
||||||
|
for p in pushers
|
||||||
|
]
|
||||||
|
|
||||||
|
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
|
||||||
|
|
||||||
|
|
||||||
class UserMediaRestServlet(RestServlet):
|
class UserMediaRestServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
Gets information about all uploaded local media for a specific `user_id`.
|
Gets information about all uploaded local media for a specific `user_id`.
|
||||||
|
|
|
@ -305,15 +305,12 @@ class MediaRepository:
|
||||||
# file_id is the ID we use to track the file locally. If we've already
|
# file_id is the ID we use to track the file locally. If we've already
|
||||||
# seen the file then reuse the existing ID, otherwise genereate a new
|
# seen the file then reuse the existing ID, otherwise genereate a new
|
||||||
# one.
|
# one.
|
||||||
if media_info:
|
|
||||||
file_id = media_info["filesystem_id"]
|
|
||||||
else:
|
|
||||||
file_id = random_string(24)
|
|
||||||
|
|
||||||
file_info = FileInfo(server_name, file_id)
|
|
||||||
|
|
||||||
# If we have an entry in the DB, try and look for it
|
# If we have an entry in the DB, try and look for it
|
||||||
if media_info:
|
if media_info:
|
||||||
|
file_id = media_info["filesystem_id"]
|
||||||
|
file_info = FileInfo(server_name, file_id)
|
||||||
|
|
||||||
if media_info["quarantined_by"]:
|
if media_info["quarantined_by"]:
|
||||||
logger.info("Media is quarantined")
|
logger.info("Media is quarantined")
|
||||||
raise NotFoundError()
|
raise NotFoundError()
|
||||||
|
@ -324,14 +321,34 @@ class MediaRepository:
|
||||||
|
|
||||||
# Failed to find the file anywhere, lets download it.
|
# Failed to find the file anywhere, lets download it.
|
||||||
|
|
||||||
media_info = await self._download_remote_file(server_name, media_id, file_id)
|
try:
|
||||||
|
media_info = await self._download_remote_file(server_name, media_id,)
|
||||||
|
except SynapseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# An exception may be because we downloaded media in another
|
||||||
|
# process, so let's check if we magically have the media.
|
||||||
|
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||||
|
if not media_info:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
file_id = media_info["filesystem_id"]
|
||||||
|
file_info = FileInfo(server_name, file_id)
|
||||||
|
|
||||||
|
# We generate thumbnails even if another process downloaded the media
|
||||||
|
# as a) it's conceivable that the other download request dies before it
|
||||||
|
# generates thumbnails, but mainly b) we want to be sure the thumbnails
|
||||||
|
# have finished being generated before responding to the client,
|
||||||
|
# otherwise they'll request thumbnails and get a 404 if they're not
|
||||||
|
# ready yet.
|
||||||
|
await self._generate_thumbnails(
|
||||||
|
server_name, media_id, file_id, media_info["media_type"]
|
||||||
|
)
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
return responder, media_info
|
return responder, media_info
|
||||||
|
|
||||||
async def _download_remote_file(
|
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
|
||||||
self, server_name: str, media_id: str, file_id: str
|
|
||||||
) -> dict:
|
|
||||||
"""Attempt to download the remote file from the given server name,
|
"""Attempt to download the remote file from the given server name,
|
||||||
using the given file_id as the local id.
|
using the given file_id as the local id.
|
||||||
|
|
||||||
|
@ -346,6 +363,8 @@ class MediaRepository:
|
||||||
The media info of the file.
|
The media info of the file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
file_id = random_string(24)
|
||||||
|
|
||||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||||
|
@ -401,22 +420,32 @@ class MediaRepository:
|
||||||
|
|
||||||
await finish()
|
await finish()
|
||||||
|
|
||||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||||
upload_name = get_filename_from_headers(headers)
|
upload_name = get_filename_from_headers(headers)
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
# Multiple remote media download requests can race (when using
|
||||||
|
# multiple media repos), so this may throw a violation constraint
|
||||||
|
# exception. If it does we'll delete the newly downloaded file from
|
||||||
|
# disk (as we're in the ctx manager).
|
||||||
|
#
|
||||||
|
# However: we've already called `finish()` so we may have also
|
||||||
|
# written to the storage providers. This is preferable to the
|
||||||
|
# alternative where we call `finish()` *after* this, where we could
|
||||||
|
# end up having an entry in the DB but fail to write the files to
|
||||||
|
# the storage providers.
|
||||||
|
await self.store.store_cached_remote_media(
|
||||||
|
origin=server_name,
|
||||||
|
media_id=media_id,
|
||||||
|
media_type=media_type,
|
||||||
|
time_now_ms=self.clock.time_msec(),
|
||||||
|
upload_name=upload_name,
|
||||||
|
media_length=length,
|
||||||
|
filesystem_id=file_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Stored remote media in file %r", fname)
|
logger.info("Stored remote media in file %r", fname)
|
||||||
|
|
||||||
await self.store.store_cached_remote_media(
|
|
||||||
origin=server_name,
|
|
||||||
media_id=media_id,
|
|
||||||
media_type=media_type,
|
|
||||||
time_now_ms=self.clock.time_msec(),
|
|
||||||
upload_name=upload_name,
|
|
||||||
media_length=length,
|
|
||||||
filesystem_id=file_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
media_info = {
|
media_info = {
|
||||||
"media_type": media_type,
|
"media_type": media_type,
|
||||||
"media_length": length,
|
"media_length": length,
|
||||||
|
@ -425,8 +454,6 @@ class MediaRepository:
|
||||||
"filesystem_id": file_id,
|
"filesystem_id": file_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
|
|
||||||
|
|
||||||
return media_info
|
return media_info
|
||||||
|
|
||||||
def _get_thumbnail_requirements(self, media_type):
|
def _get_thumbnail_requirements(self, media_type):
|
||||||
|
@ -692,42 +719,60 @@ class MediaRepository:
|
||||||
if not t_byte_source:
|
if not t_byte_source:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
file_info = FileInfo(
|
||||||
file_info = FileInfo(
|
server_name=server_name,
|
||||||
server_name=server_name,
|
file_id=file_id,
|
||||||
file_id=file_id,
|
thumbnail=True,
|
||||||
thumbnail=True,
|
thumbnail_width=t_width,
|
||||||
thumbnail_width=t_width,
|
thumbnail_height=t_height,
|
||||||
thumbnail_height=t_height,
|
thumbnail_method=t_method,
|
||||||
thumbnail_method=t_method,
|
thumbnail_type=t_type,
|
||||||
thumbnail_type=t_type,
|
url_cache=url_cache,
|
||||||
url_cache=url_cache,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
output_path = await self.media_storage.store_file(
|
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||||
t_byte_source, file_info
|
try:
|
||||||
)
|
await self.media_storage.write_to_file(t_byte_source, f)
|
||||||
finally:
|
await finish()
|
||||||
t_byte_source.close()
|
finally:
|
||||||
|
t_byte_source.close()
|
||||||
|
|
||||||
t_len = os.path.getsize(output_path)
|
t_len = os.path.getsize(fname)
|
||||||
|
|
||||||
# Write to database
|
# Write to database
|
||||||
if server_name:
|
if server_name:
|
||||||
await self.store.store_remote_media_thumbnail(
|
# Multiple remote media download requests can race (when
|
||||||
server_name,
|
# using multiple media repos), so this may throw a violation
|
||||||
media_id,
|
# constraint exception. If it does we'll delete the newly
|
||||||
file_id,
|
# generated thumbnail from disk (as we're in the ctx
|
||||||
t_width,
|
# manager).
|
||||||
t_height,
|
#
|
||||||
t_type,
|
# However: we've already called `finish()` so we may have
|
||||||
t_method,
|
# also written to the storage providers. This is preferable
|
||||||
t_len,
|
# to the alternative where we call `finish()` *after* this,
|
||||||
)
|
# where we could end up having an entry in the DB but fail
|
||||||
else:
|
# to write the files to the storage providers.
|
||||||
await self.store.store_local_thumbnail(
|
try:
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
await self.store.store_remote_media_thumbnail(
|
||||||
)
|
server_name,
|
||||||
|
media_id,
|
||||||
|
file_id,
|
||||||
|
t_width,
|
||||||
|
t_height,
|
||||||
|
t_type,
|
||||||
|
t_method,
|
||||||
|
t_len,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
thumbnail_exists = await self.store.get_remote_media_thumbnail(
|
||||||
|
server_name, media_id, t_width, t_height, t_type,
|
||||||
|
)
|
||||||
|
if not thumbnail_exists:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
await self.store.store_local_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
|
)
|
||||||
|
|
||||||
return {"width": m_width, "height": m_height}
|
return {"width": m_width, "height": m_height}
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,7 @@ class MediaStorage:
|
||||||
storage_providers: Sequence["StorageProviderWrapper"],
|
storage_providers: Sequence["StorageProviderWrapper"],
|
||||||
):
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
self.reactor = hs.get_reactor()
|
||||||
self.local_media_directory = local_media_directory
|
self.local_media_directory = local_media_directory
|
||||||
self.filepaths = filepaths
|
self.filepaths = filepaths
|
||||||
self.storage_providers = storage_providers
|
self.storage_providers = storage_providers
|
||||||
|
@ -70,13 +71,16 @@ class MediaStorage:
|
||||||
|
|
||||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||||
# Write to the main repository
|
# Write to the main repository
|
||||||
await defer_to_thread(
|
await self.write_to_file(source, f)
|
||||||
self.hs.get_reactor(), _write_file_synchronously, source, f
|
|
||||||
)
|
|
||||||
await finish_cb()
|
await finish_cb()
|
||||||
|
|
||||||
return fname
|
return fname
|
||||||
|
|
||||||
|
async def write_to_file(self, source: IO, output: IO):
|
||||||
|
"""Asynchronously write the `source` to `output`.
|
||||||
|
"""
|
||||||
|
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def store_into_file(self, file_info: FileInfo):
|
def store_into_file(self, file_info: FileInfo):
|
||||||
"""Context manager used to get a file like object to write into, as
|
"""Context manager used to get a file like object to write into, as
|
||||||
|
@ -112,14 +116,20 @@ class MediaStorage:
|
||||||
|
|
||||||
finished_called = [False]
|
finished_called = [False]
|
||||||
|
|
||||||
async def finish():
|
|
||||||
for provider in self.storage_providers:
|
|
||||||
await provider.store_file(path, file_info)
|
|
||||||
|
|
||||||
finished_called[0] = True
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(fname, "wb") as f:
|
with open(fname, "wb") as f:
|
||||||
|
|
||||||
|
async def finish():
|
||||||
|
# Ensure that all writes have been flushed and close the
|
||||||
|
# file.
|
||||||
|
f.flush()
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
for provider in self.storage_providers:
|
||||||
|
await provider.store_file(path, file_info)
|
||||||
|
|
||||||
|
finished_called[0] = True
|
||||||
|
|
||||||
yield f, fname, finish
|
yield f, fname, finish
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
|
@ -210,7 +220,7 @@ class MediaStorage:
|
||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(
|
consumer = BackgroundFileConsumer(
|
||||||
open(local_path, "wb"), self.hs.get_reactor()
|
open(local_path, "wb"), self.reactor
|
||||||
)
|
)
|
||||||
await res.write_to_consumer(consumer)
|
await res.write_to_consumer(consumer)
|
||||||
await consumer.wait()
|
await consumer.wait()
|
||||||
|
|
|
@ -94,7 +94,7 @@ def make_pool(
|
||||||
cp_openfun=lambda conn: engine.on_new_connection(
|
cp_openfun=lambda conn: engine.on_new_connection(
|
||||||
LoggingDatabaseConnection(conn, engine, "on_new_connection")
|
LoggingDatabaseConnection(conn, engine, "on_new_connection")
|
||||||
),
|
),
|
||||||
**db_config.config.get("args", {})
|
**db_config.config.get("args", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -632,7 +632,7 @@ class DatabasePool:
|
||||||
func,
|
func,
|
||||||
*args,
|
*args,
|
||||||
db_autocommit=db_autocommit,
|
db_autocommit=db_autocommit,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
|
|
|
@ -15,21 +15,31 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
from synapse.appservice import (
|
||||||
|
ApplicationService,
|
||||||
|
ApplicationServiceState,
|
||||||
|
AppServiceTransaction,
|
||||||
|
)
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
|
from synapse.storage.types import Connection
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _make_exclusive_regex(services_cache):
|
def _make_exclusive_regex(
|
||||||
|
services_cache: List[ApplicationService],
|
||||||
|
) -> Optional[Pattern]:
|
||||||
# We precompile a regex constructed from all the regexes that the AS's
|
# We precompile a regex constructed from all the regexes that the AS's
|
||||||
# have registered for exclusive users.
|
# have registered for exclusive users.
|
||||||
exclusive_user_regexes = [
|
exclusive_user_regexes = [
|
||||||
|
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
|
||||||
]
|
]
|
||||||
if exclusive_user_regexes:
|
if exclusive_user_regexes:
|
||||||
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
||||||
exclusive_user_regex = re.compile(exclusive_user_regex)
|
exclusive_user_pattern = re.compile(
|
||||||
|
exclusive_user_regex
|
||||||
|
) # type: Optional[Pattern]
|
||||||
else:
|
else:
|
||||||
# We handle this case specially otherwise the constructed regex
|
# We handle this case specially otherwise the constructed regex
|
||||||
# will always match
|
# will always match
|
||||||
exclusive_user_regex = None
|
exclusive_user_pattern = None
|
||||||
|
|
||||||
return exclusive_user_regex
|
return exclusive_user_pattern
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
self.services_cache = load_appservices(
|
self.services_cache = load_appservices(
|
||||||
hs.hostname, hs.config.app_service_config_files
|
hs.hostname, hs.config.app_service_config_files
|
||||||
)
|
)
|
||||||
|
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
def get_if_app_services_interested_in_user(self, user_id):
|
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||||
"""Check if the user is one associated with an app service (exclusively)
|
"""Check if the user is one associated with an app service (exclusively)
|
||||||
"""
|
"""
|
||||||
if self.exclusive_user_regex:
|
if self.exclusive_user_regex:
|
||||||
|
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_app_service_by_user_id(self, user_id):
|
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
|
||||||
"""Retrieve an application service from their user ID.
|
"""Retrieve an application service from their user ID.
|
||||||
|
|
||||||
All application services have associated with them a particular user ID.
|
All application services have associated with them a particular user ID.
|
||||||
|
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
a user ID to an application service.
|
a user ID to an application service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user ID to see if it is an application service.
|
user_id: The user ID to see if it is an application service.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.sender == user_id:
|
if service.sender == user_id:
|
||||||
return service
|
return service
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_app_service_by_token(self, token):
|
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
|
||||||
"""Get the application service with the given appservice token.
|
"""Get the application service with the given appservice token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): The application service token.
|
token: The application service token.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.token == token:
|
if service.token == token:
|
||||||
return service
|
return service
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_app_service_by_id(self, as_id):
|
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
|
||||||
"""Get the application service with the given appservice ID.
|
"""Get the application service with the given appservice ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
as_id (str): The application service ID.
|
as_id: The application service ID.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.id == as_id:
|
if service.id == as_id:
|
||||||
|
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
||||||
class ApplicationServiceTransactionWorkerStore(
|
class ApplicationServiceTransactionWorkerStore(
|
||||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||||
):
|
):
|
||||||
async def get_appservices_by_state(self, state):
|
async def get_appservices_by_state(
|
||||||
|
self, state: ApplicationServiceState
|
||||||
|
) -> List[ApplicationService]:
|
||||||
"""Get a list of application services based on their state.
|
"""Get a list of application services based on their state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state(ApplicationServiceState): The state to filter on.
|
state: The state to filter on.
|
||||||
Returns:
|
Returns:
|
||||||
A list of ApplicationServices, which may be empty.
|
A list of ApplicationServices, which may be empty.
|
||||||
"""
|
"""
|
||||||
|
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
services.append(service)
|
services.append(service)
|
||||||
return services
|
return services
|
||||||
|
|
||||||
async def get_appservice_state(self, service):
|
async def get_appservice_state(
|
||||||
|
self, service: ApplicationService
|
||||||
|
) -> Optional[ApplicationServiceState]:
|
||||||
"""Get the application service state.
|
"""Get the application service state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The service whose state to set.
|
service: The service whose state to set.
|
||||||
Returns:
|
Returns:
|
||||||
An ApplicationServiceState.
|
An ApplicationServiceState or none.
|
||||||
"""
|
"""
|
||||||
result = await self.db_pool.simple_select_one(
|
result = await self.db_pool.simple_select_one(
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
|
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
return result.get("state")
|
return result.get("state")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def set_appservice_state(self, service, state) -> None:
|
async def set_appservice_state(
|
||||||
|
self, service: ApplicationService, state: ApplicationServiceState
|
||||||
|
) -> None:
|
||||||
"""Set the application service state.
|
"""Set the application service state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The service whose state to set.
|
service: The service whose state to set.
|
||||||
state(ApplicationServiceState): The connectivity state to apply.
|
state: The connectivity state to apply.
|
||||||
"""
|
"""
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||||
|
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"create_appservice_txn", _create_appservice_txn
|
"create_appservice_txn", _create_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete_appservice_txn(self, txn_id, service) -> None:
|
async def complete_appservice_txn(
|
||||||
|
self, txn_id: int, service: ApplicationService
|
||||||
|
) -> None:
|
||||||
"""Completes an application service transaction.
|
"""Completes an application service transaction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn_id(str): The transaction ID being completed.
|
txn_id: The transaction ID being completed.
|
||||||
service(ApplicationService): The application service which was sent
|
service: The application service which was sent this transaction.
|
||||||
this transaction.
|
|
||||||
"""
|
"""
|
||||||
txn_id = int(txn_id)
|
txn_id = int(txn_id)
|
||||||
|
|
||||||
|
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
# has probably missed some events), so whine loudly but still continue,
|
# has probably missed some events), so whine loudly but still continue,
|
||||||
# since it shouldn't fail completion of the transaction.
|
# since it shouldn't fail completion of the transaction.
|
||||||
last_txn_id = self._get_last_txn(txn, service.id)
|
last_txn_id = self._get_last_txn(txn, service.id)
|
||||||
if (last_txn_id + 1) != txn_id:
|
if (txn_id + 1) != txn_id:
|
||||||
logger.error(
|
logger.error(
|
||||||
"appservice: Completing a transaction which has an ID > 1 from "
|
"appservice: Completing a transaction which has an ID > 1 from "
|
||||||
"the last ID sent to this AS. We've either dropped events or "
|
"the last ID sent to this AS. We've either dropped events or "
|
||||||
|
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"complete_appservice_txn", _complete_appservice_txn
|
"complete_appservice_txn", _complete_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_oldest_unsent_txn(self, service):
|
async def get_oldest_unsent_txn(
|
||||||
"""Get the oldest transaction which has not been sent for this
|
self, service: ApplicationService
|
||||||
service.
|
) -> Optional[AppServiceTransaction]:
|
||||||
|
"""Get the oldest transaction which has not been sent for this service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The app service to get the oldest txn.
|
service: The app service to get the oldest txn.
|
||||||
Returns:
|
Returns:
|
||||||
An AppServiceTransaction or None.
|
An AppServiceTransaction or None.
|
||||||
"""
|
"""
|
||||||
|
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id):
|
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||||
(service_id,),
|
(service_id,),
|
||||||
|
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
else:
|
else:
|
||||||
return int(last_txn_id[0]) # select 'last_txn' col
|
return int(last_txn_id[0]) # select 'last_txn' col
|
||||||
|
|
||||||
async def set_appservice_last_pos(self, pos) -> None:
|
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||||
def set_appservice_last_pos_txn(txn):
|
def set_appservice_last_pos_txn(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||||
|
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_new_events_for_appservice(self, current_id, limit):
|
async def get_new_events_for_appservice(
|
||||||
|
self, current_id: int, limit: int
|
||||||
|
) -> Tuple[int, List[EventBase]]:
|
||||||
"""Get all new events for an appservice"""
|
"""Get all new events for an appservice"""
|
||||||
|
|
||||||
def get_new_events_for_appservice_txn(txn):
|
def get_new_events_for_appservice_txn(txn):
|
||||||
|
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_type_stream_id_for_appservice(
|
async def set_type_stream_id_for_appservice(
|
||||||
self, service: ApplicationService, type: str, pos: int
|
self, service: ApplicationService, type: str, pos: Optional[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
if type not in ("read_receipt", "presence"):
|
if type not in ("read_receipt", "presence"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
and original_event.internal_metadata.is_redacted()
|
and original_event.internal_metadata.is_redacted()
|
||||||
):
|
):
|
||||||
# Redaction was allowed
|
# Redaction was allowed
|
||||||
pruned_json = frozendict_json_encoder.encode(
|
pruned_json = json_encoder.encode(
|
||||||
prune_event_dict(
|
prune_event_dict(
|
||||||
original_event.room_version, original_event.get_dict()
|
original_event.room_version, original_event.get_dict()
|
||||||
)
|
)
|
||||||
|
@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
return
|
return
|
||||||
|
|
||||||
# Prune the event's dict then convert it to JSON.
|
# Prune the event's dict then convert it to JSON.
|
||||||
pruned_json = frozendict_json_encoder.encode(
|
pruned_json = json_encoder.encode(
|
||||||
prune_event_dict(event.room_version, event.get_dict())
|
prune_event_dict(event.room_version, event.get_dict())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.search import SearchEntry
|
from synapse.storage.databases.main.search import SearchEntry
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.types import StateMap, get_domain_from_id
|
from synapse.types import StateMap, get_domain_from_id
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -769,9 +769,7 @@ class PersistEventsStore:
|
||||||
logger.exception("")
|
logger.exception("")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
metadata_json = frozendict_json_encoder.encode(
|
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
|
||||||
event.internal_metadata.get_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
|
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
|
||||||
txn.execute(sql, (metadata_json, event.event_id))
|
txn.execute(sql, (metadata_json, event.event_id))
|
||||||
|
@ -826,10 +824,10 @@ class PersistEventsStore:
|
||||||
{
|
{
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"internal_metadata": frozendict_json_encoder.encode(
|
"internal_metadata": json_encoder.encode(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
),
|
),
|
||||||
"json": frozendict_json_encoder.encode(event_dict(event)),
|
"json": json_encoder.encode(event_dict(event)),
|
||||||
"format_version": event.format_version,
|
"format_version": event.format_version,
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
|
|
|
@ -31,6 +31,7 @@ from synapse.api.room_versions import (
|
||||||
RoomVersions,
|
RoomVersions,
|
||||||
)
|
)
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||||
from synapse.types import Collection, get_domain_from_id
|
from synapse.types import Collection, JsonDict, get_domain_from_id
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return event_map
|
return event_map
|
||||||
|
|
||||||
|
async def get_stripped_room_state_from_event_context(
|
||||||
|
self,
|
||||||
|
context: EventContext,
|
||||||
|
state_types_to_include: List[EventTypes],
|
||||||
|
membership_user_id: Optional[str] = None,
|
||||||
|
) -> List[JsonDict]:
|
||||||
|
"""
|
||||||
|
Retrieve the stripped state from a room, given an event context to retrieve state
|
||||||
|
from as well as the state types to include. Optionally, include the membership
|
||||||
|
events from a specific user.
|
||||||
|
|
||||||
|
"Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
|
||||||
|
are included from each state event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The event context to retrieve state of the room from.
|
||||||
|
state_types_to_include: The type of state events to include.
|
||||||
|
membership_user_id: An optional user ID to include the stripped membership state
|
||||||
|
events of. This is useful when generating the stripped state of a room for
|
||||||
|
invites. We want to send membership events of the inviter, so that the
|
||||||
|
invitee can display the inviter's profile information if the room lacks any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of dictionaries, each representing a stripped state event from the room.
|
||||||
|
"""
|
||||||
|
current_state_ids = await context.get_current_state_ids()
|
||||||
|
|
||||||
|
# We know this event is not an outlier, so this must be
|
||||||
|
# non-None.
|
||||||
|
assert current_state_ids is not None
|
||||||
|
|
||||||
|
# The state to include
|
||||||
|
state_to_include_ids = [
|
||||||
|
e_id
|
||||||
|
for k, e_id in current_state_ids.items()
|
||||||
|
if k[0] in state_types_to_include
|
||||||
|
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
|
||||||
|
]
|
||||||
|
|
||||||
|
state_to_include = await self.get_events(state_to_include_ids)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": e.type,
|
||||||
|
"state_key": e.state_key,
|
||||||
|
"content": e.content,
|
||||||
|
"sender": e.sender,
|
||||||
|
}
|
||||||
|
for e in state_to_include.values()
|
||||||
|
]
|
||||||
|
|
||||||
def _do_fetch(self, conn):
|
def _do_fetch(self, conn):
|
||||||
"""Takes a database connection and waits for requests for events from
|
"""Takes a database connection and waits for requests for events from
|
||||||
the _event_fetch_list queue.
|
the _event_fetch_list queue.
|
||||||
|
@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
def get_all_new_forward_event_rows(txn):
|
def get_all_new_forward_event_rows(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||||
" state_key, redacts, relates_to_id"
|
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||||
" FROM events AS e"
|
" FROM events AS e"
|
||||||
" LEFT JOIN redactions USING (event_id)"
|
" LEFT JOIN redactions USING (event_id)"
|
||||||
" LEFT JOIN state_events USING (event_id)"
|
" LEFT JOIN state_events USING (event_id)"
|
||||||
" LEFT JOIN event_relations USING (event_id)"
|
" LEFT JOIN event_relations USING (event_id)"
|
||||||
|
" LEFT JOIN room_memberships USING (event_id)"
|
||||||
|
" LEFT JOIN rejections USING (event_id)"
|
||||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||||
" AND instance_name = ?"
|
" AND instance_name = ?"
|
||||||
" ORDER BY stream_ordering ASC"
|
" ORDER BY stream_ordering ASC"
|
||||||
|
@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
def get_ex_outlier_stream_rows_txn(txn):
|
def get_ex_outlier_stream_rows_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||||
" state_key, redacts, relates_to_id"
|
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||||
" FROM events AS e"
|
" FROM events AS e"
|
||||||
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
||||||
" LEFT JOIN redactions USING (event_id)"
|
" LEFT JOIN redactions USING (event_id)"
|
||||||
" LEFT JOIN state_events USING (event_id)"
|
" LEFT JOIN state_events USING (event_id)"
|
||||||
" LEFT JOIN event_relations USING (event_id)"
|
" LEFT JOIN event_relations USING (event_id)"
|
||||||
|
" LEFT JOIN room_memberships USING (event_id)"
|
||||||
|
" LEFT JOIN rejections USING (event_id)"
|
||||||
" WHERE ? < event_stream_ordering"
|
" WHERE ? < event_stream_ordering"
|
||||||
" AND event_stream_ordering <= ?"
|
" AND event_stream_ordering <= ?"
|
||||||
" AND out.instance_name = ?"
|
" AND out.instance_name = ?"
|
||||||
|
|
|
@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="get_remote_media_thumbnails",
|
desc="get_remote_media_thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_remote_media_thumbnail(
|
||||||
|
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Fetch the thumbnail info of given width, height and type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self.db_pool.simple_select_one(
|
||||||
|
table="remote_media_cache_thumbnails",
|
||||||
|
keyvalues={
|
||||||
|
"media_origin": origin,
|
||||||
|
"media_id": media_id,
|
||||||
|
"thumbnail_width": t_width,
|
||||||
|
"thumbnail_height": t_height,
|
||||||
|
"thumbnail_type": t_type,
|
||||||
|
},
|
||||||
|
retcols=(
|
||||||
|
"thumbnail_width",
|
||||||
|
"thumbnail_height",
|
||||||
|
"thumbnail_method",
|
||||||
|
"thumbnail_type",
|
||||||
|
"thumbnail_length",
|
||||||
|
"filesystem_id",
|
||||||
|
),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_remote_media_thumbnail",
|
||||||
|
)
|
||||||
|
|
||||||
async def store_remote_media_thumbnail(
|
async def store_remote_media_thumbnail(
|
||||||
self,
|
self,
|
||||||
origin,
|
origin,
|
||||||
|
|
|
@ -18,6 +18,8 @@ import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
|
@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True)
|
||||||
|
class TokenLookupResult:
|
||||||
|
"""Result of looking up an access token.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
user_id: The user that this token authenticates as
|
||||||
|
is_guest
|
||||||
|
shadow_banned
|
||||||
|
token_id: The ID of the access token looked up
|
||||||
|
device_id: The device associated with the token, if any.
|
||||||
|
valid_until_ms: The timestamp the token expires, if any.
|
||||||
|
token_owner: The "owner" of the token. This is either the same as the
|
||||||
|
user, or a server admin who is logged in as the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
is_guest = attr.ib(type=bool, default=False)
|
||||||
|
shadow_banned = attr.ib(type=bool, default=False)
|
||||||
|
token_id = attr.ib(type=Optional[int], default=None)
|
||||||
|
device_id = attr.ib(type=Optional[str], default=None)
|
||||||
|
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||||
|
token_owner = attr.ib(type=str)
|
||||||
|
|
||||||
|
# Make the token owner default to the user ID, which is the common case.
|
||||||
|
@token_owner.default
|
||||||
|
def _default_token_owner(self):
|
||||||
|
return self.user_id
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
return is_trial
|
return is_trial
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The access token of a user.
|
token: The access token of a user.
|
||||||
Returns:
|
Returns:
|
||||||
None, if the token did not match, otherwise dict
|
None, if the token did not match, otherwise a `TokenLookupResult`
|
||||||
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
|
||||||
`valid_until_ms`.
|
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_user_by_access_token", self._query_for_auth, token
|
"get_user_by_access_token", self._query_for_auth, token
|
||||||
|
@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT users.name,
|
SELECT users.name as user_id,
|
||||||
users.is_guest,
|
users.is_guest,
|
||||||
users.shadow_banned,
|
users.shadow_banned,
|
||||||
access_tokens.id as token_id,
|
access_tokens.id as token_id,
|
||||||
access_tokens.device_id,
|
access_tokens.device_id,
|
||||||
access_tokens.valid_until_ms
|
access_tokens.valid_until_ms,
|
||||||
|
access_tokens.user_id as token_owner
|
||||||
FROM users
|
FROM users
|
||||||
INNER JOIN access_tokens on users.name = access_tokens.user_id
|
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||||
WHERE token = ?
|
WHERE token = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (token,))
|
txn.execute(sql, (token,))
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0]
|
return TokenLookupResult(**rows[0])
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Whether the access token is an admin token for controlling another user.
|
||||||
|
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
|
|
@ -29,6 +29,7 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.appservice.api import ApplicationService
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
# define a version of typing.Collection that works on python 3.5
|
# define a version of typing.Collection that works on python 3.5
|
||||||
|
@ -74,6 +76,7 @@ class Requester(
|
||||||
"shadow_banned",
|
"shadow_banned",
|
||||||
"device_id",
|
"device_id",
|
||||||
"app_service",
|
"app_service",
|
||||||
|
"authenticated_entity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
@ -104,6 +107,7 @@ class Requester(
|
||||||
"shadow_banned": self.shadow_banned,
|
"shadow_banned": self.shadow_banned,
|
||||||
"device_id": self.device_id,
|
"device_id": self.device_id,
|
||||||
"app_server_id": self.app_service.id if self.app_service else None,
|
"app_server_id": self.app_service.id if self.app_service else None,
|
||||||
|
"authenticated_entity": self.authenticated_entity,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -129,16 +133,18 @@ class Requester(
|
||||||
shadow_banned=input["shadow_banned"],
|
shadow_banned=input["shadow_banned"],
|
||||||
device_id=input["device_id"],
|
device_id=input["device_id"],
|
||||||
app_service=appservice,
|
app_service=appservice,
|
||||||
|
authenticated_entity=input["authenticated_entity"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_requester(
|
def create_requester(
|
||||||
user_id,
|
user_id: Union[str, "UserID"],
|
||||||
access_token_id=None,
|
access_token_id: Optional[int] = None,
|
||||||
is_guest=False,
|
is_guest: Optional[bool] = False,
|
||||||
shadow_banned=False,
|
shadow_banned: Optional[bool] = False,
|
||||||
device_id=None,
|
device_id: Optional[str] = None,
|
||||||
app_service=None,
|
app_service: Optional["ApplicationService"] = None,
|
||||||
|
authenticated_entity: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new ``Requester`` object
|
Create a new ``Requester`` object
|
||||||
|
@ -151,14 +157,27 @@ def create_requester(
|
||||||
shadow_banned (bool): True if the user making this request is shadow-banned.
|
shadow_banned (bool): True if the user making this request is shadow-banned.
|
||||||
device_id (str|None): device_id which was set at authentication time
|
device_id (str|None): device_id which was set at authentication time
|
||||||
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
||||||
|
authenticated_entity: The entity that authenticated when making the request.
|
||||||
|
This is different to the user_id when an admin user or the server is
|
||||||
|
"puppeting" the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Requester
|
Requester
|
||||||
"""
|
"""
|
||||||
if not isinstance(user_id, UserID):
|
if not isinstance(user_id, UserID):
|
||||||
user_id = UserID.from_string(user_id)
|
user_id = UserID.from_string(user_id)
|
||||||
|
|
||||||
|
if authenticated_entity is None:
|
||||||
|
authenticated_entity = user_id.to_string()
|
||||||
|
|
||||||
return Requester(
|
return Requester(
|
||||||
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
|
user_id,
|
||||||
|
access_token_id,
|
||||||
|
is_guest,
|
||||||
|
shadow_banned,
|
||||||
|
device_id,
|
||||||
|
app_service,
|
||||||
|
authenticated_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer, task
|
from twisted.internet import defer, task
|
||||||
|
|
||||||
|
@ -31,9 +32,26 @@ def _reject_invalid_json(val):
|
||||||
raise ValueError("Invalid JSON value: '%s'" % val)
|
raise ValueError("Invalid JSON value: '%s'" % val)
|
||||||
|
|
||||||
|
|
||||||
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
|
def _handle_frozendict(obj):
|
||||||
# ensure that valid JSON is produced.
|
"""Helper for json_encoder. Makes frozendicts serializable by returning
|
||||||
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
|
the underlying dict
|
||||||
|
"""
|
||||||
|
if type(obj) is frozendict:
|
||||||
|
# fishing the protected dict out of the object is a bit nasty,
|
||||||
|
# but we don't really want the overhead of copying the dict.
|
||||||
|
return obj._dict
|
||||||
|
raise TypeError(
|
||||||
|
"Object of type %s is not JSON serializable" % obj.__class__.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A custom JSON encoder which:
|
||||||
|
# * handles frozendicts
|
||||||
|
# * produces valid JSON (no NaNs etc)
|
||||||
|
# * reduces redundant whitespace
|
||||||
|
json_encoder = json.JSONEncoder(
|
||||||
|
allow_nan=False, separators=(",", ":"), default=_handle_frozendict
|
||||||
|
)
|
||||||
|
|
||||||
# Create a custom decoder to reject Python extensions to JSON.
|
# Create a custom decoder to reject Python extensions to JSON.
|
||||||
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
||||||
|
|
|
@ -13,10 +13,23 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -24,6 +37,7 @@ from twisted.internet import defer
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.deferred_cache import DeferredCache
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
|
||||||
|
|
||||||
|
|
||||||
class _CacheDescriptorBase:
|
class _CacheDescriptorBase:
|
||||||
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
|
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
|
||||||
arg_spec = inspect.getfullargspec(orig)
|
arg_spec = inspect.getfullargspec(orig)
|
||||||
|
@ -97,8 +111,107 @@ class _CacheDescriptorBase:
|
||||||
|
|
||||||
self.add_cache_context = cache_context
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
|
self.cache_key_builder = get_cache_key_builder(
|
||||||
|
self.arg_names, self.arg_defaults
|
||||||
|
)
|
||||||
|
|
||||||
class CacheDescriptor(_CacheDescriptorBase):
|
|
||||||
|
class _LruCachedFunction(Generic[F]):
|
||||||
|
cache = None # type: LruCache[CacheKey, Any]
|
||||||
|
__call__ = None # type: F
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cache(
|
||||||
|
max_entries: int = 1000, cache_context: bool = False,
|
||||||
|
) -> Callable[[F], _LruCachedFunction[F]]:
|
||||||
|
"""A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
|
This is more-or-less a drop-in equivalent to functools.lru_cache, although note
|
||||||
|
that the signature is slightly different.
|
||||||
|
|
||||||
|
The main differences with functools.lru_cache are:
|
||||||
|
(a) the size of the cache can be controlled via the cache_factor mechanism
|
||||||
|
(b) the wrapped function can request a "cache_context" which provides a
|
||||||
|
callback mechanism to indicate that the result is no longer valid
|
||||||
|
(c) prometheus metrics are exposed automatically.
|
||||||
|
|
||||||
|
The function should take zero or more arguments, which are used as the key for the
|
||||||
|
cache. Single-argument functions use that argument as the cache key; otherwise the
|
||||||
|
arguments are built into a tuple.
|
||||||
|
|
||||||
|
Cached functions can be "chained" (i.e. a cached function can call other cached
|
||||||
|
functions and get appropriately invalidated when they called caches are
|
||||||
|
invalidated) by adding a special "cache_context" argument to the function
|
||||||
|
and passing that as a kwarg to all caches called. For example:
|
||||||
|
|
||||||
|
@lru_cache(cache_context=True)
|
||||||
|
def foo(self, key, cache_context):
|
||||||
|
r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||||
|
r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||||
|
return r1 + r2
|
||||||
|
|
||||||
|
The wrapped function also has a 'cache' property which offers direct access to the
|
||||||
|
underlying LruCache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def func(orig: F) -> _LruCachedFunction[F]:
|
||||||
|
desc = LruCacheDescriptor(
|
||||||
|
orig, max_entries=max_entries, cache_context=cache_context,
|
||||||
|
)
|
||||||
|
return cast(_LruCachedFunction[F], desc)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
|
"""Helper for @lru_cache"""
|
||||||
|
|
||||||
|
class _Sentinel(enum.Enum):
|
||||||
|
sentinel = object()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, orig, max_entries: int = 1000, cache_context: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(orig, num_args=None, cache_context=cache_context)
|
||||||
|
self.max_entries = max_entries
|
||||||
|
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
cache = LruCache(
|
||||||
|
cache_name=self.orig.__name__, max_size=self.max_entries,
|
||||||
|
) # type: LruCache[CacheKey, Any]
|
||||||
|
|
||||||
|
get_cache_key = self.cache_key_builder
|
||||||
|
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
||||||
|
|
||||||
|
@functools.wraps(self.orig)
|
||||||
|
def _wrapped(*args, **kwargs):
|
||||||
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
||||||
|
|
||||||
|
cache_key = get_cache_key(args, kwargs)
|
||||||
|
|
||||||
|
ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
|
||||||
|
if ret != sentinel:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
|
# has asked for one
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
|
||||||
|
|
||||||
|
ret2 = self.orig(obj, *args, **kwargs)
|
||||||
|
cache.set(cache_key, ret2, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ret2
|
||||||
|
|
||||||
|
wrapped = cast(_CachedFunction, _wrapped)
|
||||||
|
wrapped.cache = cache
|
||||||
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
This caches deferreds, rather than the results themselves. Deferreds that
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
cache_context=False,
|
cache_context=False,
|
||||||
iterable=False,
|
iterable=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
|
@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
iterable=self.iterable,
|
iterable=self.iterable,
|
||||||
) # type: DeferredCache[CacheKey, Any]
|
) # type: DeferredCache[CacheKey, Any]
|
||||||
|
|
||||||
def get_cache_key_gen(args, kwargs):
|
get_cache_key = self.cache_key_builder
|
||||||
"""Given some args/kwargs return a generator that resolves into
|
|
||||||
the cache_key.
|
|
||||||
|
|
||||||
We loop through each arg name, looking up if its in the `kwargs`,
|
|
||||||
otherwise using the next argument in `args`. If there are no more
|
|
||||||
args then we try looking the arg name up in the defaults
|
|
||||||
"""
|
|
||||||
pos = 0
|
|
||||||
for nm in self.arg_names:
|
|
||||||
if nm in kwargs:
|
|
||||||
yield kwargs[nm]
|
|
||||||
elif pos < len(args):
|
|
||||||
yield args[pos]
|
|
||||||
pos += 1
|
|
||||||
else:
|
|
||||||
yield self.arg_defaults[nm]
|
|
||||||
|
|
||||||
# By default our cache key is a tuple, but if there is only one item
|
|
||||||
# then don't bother wrapping in a tuple. This is to save memory.
|
|
||||||
if self.num_args == 1:
|
|
||||||
nm = self.arg_names[0]
|
|
||||||
|
|
||||||
def get_cache_key(args, kwargs):
|
|
||||||
if nm in kwargs:
|
|
||||||
return kwargs[nm]
|
|
||||||
elif len(args):
|
|
||||||
return args[0]
|
|
||||||
else:
|
|
||||||
return self.arg_defaults[nm]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def get_cache_key(args, kwargs):
|
|
||||||
return tuple(get_cache_key_gen(args, kwargs))
|
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def _wrapped(*args, **kwargs):
|
def _wrapped(*args, **kwargs):
|
||||||
|
@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
||||||
else:
|
else:
|
||||||
wrapped.invalidate = cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
|
||||||
wrapped.invalidate_many = cache.invalidate_many
|
wrapped.invalidate_many = cache.invalidate_many
|
||||||
wrapped.prefill = cache.prefill
|
wrapped.prefill = cache.prefill
|
||||||
|
|
||||||
|
@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class CacheListDescriptor(_CacheDescriptorBase):
|
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
|
@ -382,11 +459,13 @@ class _CacheContext:
|
||||||
on a lower level.
|
on a lower level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Cache = Union[DeferredCache, LruCache]
|
||||||
|
|
||||||
_cache_context_objects = (
|
_cache_context_objects = (
|
||||||
WeakValueDictionary()
|
WeakValueDictionary()
|
||||||
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
|
) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
|
||||||
|
|
||||||
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
|
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._cache_key = cache_key
|
self._cache_key = cache_key
|
||||||
|
|
||||||
|
@ -396,8 +475,8 @@ class _CacheContext:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(
|
def get_instance(
|
||||||
cls, cache, cache_key
|
cls, cache: "_CacheContext.Cache", cache_key: CacheKey
|
||||||
): # type: (DeferredCache, CacheKey) -> _CacheContext
|
) -> "_CacheContext":
|
||||||
"""Returns an instance constructed with the given arguments.
|
"""Returns an instance constructed with the given arguments.
|
||||||
|
|
||||||
A new instance is only created if none already exists.
|
A new instance is only created if none already exists.
|
||||||
|
@ -418,7 +497,7 @@ def cached(
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
) -> Callable[[F], _CachedFunction[F]]:
|
) -> Callable[[F], _CachedFunction[F]]:
|
||||||
func = lambda orig: CacheDescriptor(
|
func = lambda orig: DeferredCacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
|
@ -460,7 +539,7 @@ def cachedList(
|
||||||
def batch_do_something(self, first_arg, second_args):
|
def batch_do_something(self, first_arg, second_args):
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
func = lambda orig: CacheListDescriptor(
|
func = lambda orig: DeferredCacheListDescriptor(
|
||||||
orig,
|
orig,
|
||||||
cached_method_name=cached_method_name,
|
cached_method_name=cached_method_name,
|
||||||
list_name=list_name,
|
list_name=list_name,
|
||||||
|
@ -468,3 +547,65 @@ def cachedList(
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_key_builder(
|
||||||
|
param_names: Sequence[str], param_defaults: Mapping[str, Any]
|
||||||
|
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
|
||||||
|
"""Construct a function which will build cache keys suitable for a cached function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_names: list of formal parameter names for the cached function
|
||||||
|
param_defaults: a mapping from parameter name to default value for that param
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function which will take an (args, kwargs) pair and return a cache key
|
||||||
|
"""
|
||||||
|
|
||||||
|
# By default our cache key is a tuple, but if there is only one item
|
||||||
|
# then don't bother wrapping in a tuple. This is to save memory.
|
||||||
|
|
||||||
|
if len(param_names) == 1:
|
||||||
|
nm = param_names[0]
|
||||||
|
|
||||||
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
|
if nm in kwargs:
|
||||||
|
return kwargs[nm]
|
||||||
|
elif len(args):
|
||||||
|
return args[0]
|
||||||
|
else:
|
||||||
|
return param_defaults[nm]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
|
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
|
||||||
|
|
||||||
|
return get_cache_key
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_key_gen(
|
||||||
|
param_names: Iterable[str],
|
||||||
|
param_defaults: Mapping[str, Any],
|
||||||
|
args: Sequence[Any],
|
||||||
|
kwargs: Mapping[str, Any],
|
||||||
|
) -> Iterable[Any]:
|
||||||
|
"""Given some args/kwargs return a generator that resolves into
|
||||||
|
the cache_key.
|
||||||
|
|
||||||
|
This is essentially the same operation as `inspect.getcallargs`, but optimised so
|
||||||
|
that we don't need to inspect the target function for each call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We loop through each arg name, looking up if its in the `kwargs`,
|
||||||
|
# otherwise using the next argument in `args`. If there are no more
|
||||||
|
# args then we try looking the arg name up in the defaults.
|
||||||
|
pos = 0
|
||||||
|
for nm in param_names:
|
||||||
|
if nm in kwargs:
|
||||||
|
yield kwargs[nm]
|
||||||
|
elif pos < len(args):
|
||||||
|
yield args[pos]
|
||||||
|
pos += 1
|
||||||
|
else:
|
||||||
|
yield param_defaults[nm]
|
||||||
|
|
|
@ -13,8 +13,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,23 +47,3 @@ def unfreeze(o):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def _handle_frozendict(obj):
|
|
||||||
"""Helper for EventEncoder. Makes frozendicts serializable by returning
|
|
||||||
the underlying dict
|
|
||||||
"""
|
|
||||||
if type(obj) is frozendict:
|
|
||||||
# fishing the protected dict out of the object is a bit nasty,
|
|
||||||
# but we don't really want the overhead of copying the dict.
|
|
||||||
return obj._dict
|
|
||||||
raise TypeError(
|
|
||||||
"Object of type %s is not JSON serializable" % obj.__class__.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# A JSONEncoder which is capable of encoding frozendicts without barfing.
|
|
||||||
# Additionally reduce the whitespace produced by JSON encoding.
|
|
||||||
frozendict_json_encoder = json.JSONEncoder(
|
|
||||||
allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
|
|
||||||
)
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
|
||||||
failure_ts,
|
failure_ts,
|
||||||
retry_interval,
|
retry_interval,
|
||||||
backoff_on_failure=backoff_on_failure,
|
backoff_on_failure=backoff_on_failure,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,45 +21,6 @@ except ImportError:
|
||||||
from twisted.internet.pollreactor import PollReactor as Reactor
|
from twisted.internet.pollreactor import PollReactor as Reactor
|
||||||
from twisted.internet.main import installReactor
|
from twisted.internet.main import installReactor
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests.utils import default_config, setup_test_homeserver
|
|
||||||
|
|
||||||
|
|
||||||
async def make_homeserver(reactor, config=None):
|
|
||||||
"""
|
|
||||||
Make a Homeserver suitable for running benchmarks against.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reactor: A Twisted reactor to run under.
|
|
||||||
config: A HomeServerConfig to use, or None.
|
|
||||||
"""
|
|
||||||
cleanup_tasks = []
|
|
||||||
clock = Clock(reactor)
|
|
||||||
|
|
||||||
if not config:
|
|
||||||
config = default_config("test")
|
|
||||||
|
|
||||||
config_obj = HomeServerConfig()
|
|
||||||
config_obj.parse_config_dict(config, "", "")
|
|
||||||
|
|
||||||
hs = setup_test_homeserver(
|
|
||||||
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
|
|
||||||
)
|
|
||||||
stor = hs.get_datastore()
|
|
||||||
|
|
||||||
# Run the database background updates.
|
|
||||||
if hasattr(stor.db_pool.updates, "do_next_background_update"):
|
|
||||||
while not await stor.db_pool.updates.has_completed_background_updates():
|
|
||||||
await stor.db_pool.updates.do_next_background_update(1)
|
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
for i in cleanup_tasks:
|
|
||||||
i()
|
|
||||||
|
|
||||||
return hs, clock.sleep, cleanup
|
|
||||||
|
|
||||||
|
|
||||||
def make_reactor():
|
def make_reactor():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,20 +12,20 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from argparse import REMAINDER
|
from argparse import REMAINDER
|
||||||
from contextlib import redirect_stderr
|
from contextlib import redirect_stderr
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
import pyperf
|
import pyperf
|
||||||
from synmark import make_reactor
|
|
||||||
from synmark.suites import SUITES
|
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
from twisted.logger import globalLogBeginner, textFileLogObserver
|
from twisted.logger import globalLogBeginner, textFileLogObserver
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
|
from synmark import make_reactor
|
||||||
|
from synmark.suites import SUITES
|
||||||
|
|
||||||
from tests.utils import setupdb
|
from tests.utils import setupdb
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,20 +13,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from pyperf import perf_counter
|
from pyperf import perf_counter
|
||||||
from synmark import make_homeserver
|
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.protocol import ServerFactory
|
from twisted.internet.protocol import ServerFactory
|
||||||
from twisted.logger import LogBeginner, Logger, LogPublisher
|
from twisted.logger import LogBeginner, LogPublisher
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
from twisted.protocols.basic import LineOnlyReceiver
|
||||||
|
|
||||||
from synapse.logging._structured import setup_structured_logging
|
from synapse.config.logger import _setup_stdlib_logging
|
||||||
|
from synapse.logging import RemoteHandler
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
|
||||||
class LineCounter(LineOnlyReceiver):
|
class LineCounter(LineOnlyReceiver):
|
||||||
|
@ -62,7 +64,15 @@ async def main(reactor, loops):
|
||||||
logger_factory.on_done = Deferred()
|
logger_factory.on_done = Deferred()
|
||||||
port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1")
|
port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1")
|
||||||
|
|
||||||
hs, wait, cleanup = await make_homeserver(reactor)
|
# A fake homeserver config.
|
||||||
|
class Config:
|
||||||
|
server_name = "synmark-" + str(loops)
|
||||||
|
no_redirect_stdio = True
|
||||||
|
|
||||||
|
hs_config = Config()
|
||||||
|
|
||||||
|
# To be able to sleep.
|
||||||
|
clock = Clock(reactor)
|
||||||
|
|
||||||
errors = StringIO()
|
errors = StringIO()
|
||||||
publisher = LogPublisher()
|
publisher = LogPublisher()
|
||||||
|
@ -72,47 +82,49 @@ async def main(reactor, loops):
|
||||||
)
|
)
|
||||||
|
|
||||||
log_config = {
|
log_config = {
|
||||||
"loggers": {"synapse": {"level": "DEBUG"}},
|
"version": 1,
|
||||||
"drains": {
|
"loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}},
|
||||||
|
"formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}},
|
||||||
|
"handlers": {
|
||||||
"tersejson": {
|
"tersejson": {
|
||||||
"type": "network_json_terse",
|
"class": "synapse.logging.RemoteHandler",
|
||||||
"host": "127.0.0.1",
|
"host": "127.0.0.1",
|
||||||
"port": port.getHost().port,
|
"port": port.getHost().port,
|
||||||
"maximum_buffer": 100,
|
"maximum_buffer": 100,
|
||||||
|
"_reactor": reactor,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher)
|
logger = logging.getLogger("synapse.logging.test_terse_json")
|
||||||
logging_system = setup_structured_logging(
|
_setup_stdlib_logging(
|
||||||
hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False
|
hs_config, log_config, logBeginner=beginner,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for it to connect...
|
# Wait for it to connect...
|
||||||
await logging_system._observers[0]._service.whenConnected()
|
for handler in logging.getLogger("synapse").handlers:
|
||||||
|
if isinstance(handler, RemoteHandler):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Improperly configured: no RemoteHandler found.")
|
||||||
|
|
||||||
|
await handler._service.whenConnected()
|
||||||
|
|
||||||
start = perf_counter()
|
start = perf_counter()
|
||||||
|
|
||||||
# Send a bunch of useful messages
|
# Send a bunch of useful messages
|
||||||
for i in range(0, loops):
|
for i in range(0, loops):
|
||||||
logger.info("test message %s" % (i,))
|
logger.info("test message %s", i)
|
||||||
|
|
||||||
if (
|
if len(handler._buffer) == handler.maximum_buffer:
|
||||||
len(logging_system._observers[0]._buffer)
|
while len(handler._buffer) > handler.maximum_buffer / 2:
|
||||||
== logging_system._observers[0].maximum_buffer
|
await clock.sleep(0.01)
|
||||||
):
|
|
||||||
while (
|
|
||||||
len(logging_system._observers[0]._buffer)
|
|
||||||
> logging_system._observers[0].maximum_buffer / 2
|
|
||||||
):
|
|
||||||
await wait(0.01)
|
|
||||||
|
|
||||||
await logger_factory.on_done
|
await logger_factory.on_done
|
||||||
|
|
||||||
end = perf_counter() - start
|
end = perf_counter() - start
|
||||||
|
|
||||||
logging_system.stop()
|
handler.close()
|
||||||
port.stopListening()
|
port.stopListening()
|
||||||
cleanup()
|
|
||||||
|
|
||||||
return end
|
return end
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.api.errors import (
|
||||||
MissingClientTokenError,
|
MissingClientTokenError,
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
user_info = TokenLookupResult(
|
||||||
|
user_id=self.test_user, token_id=5, device_id="device"
|
||||||
|
)
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(user_info)
|
return_value=defer.succeed(user_info)
|
||||||
)
|
)
|
||||||
|
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(user_info)
|
return_value=defer.succeed(user_info)
|
||||||
)
|
)
|
||||||
|
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_get_user_from_macaroon(self):
|
def test_get_user_from_macaroon(self):
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(
|
return_value=defer.succeed(
|
||||||
{"name": "@baldrick:matrix.org", "device_id": "device"}
|
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
|
||||||
user_info = yield defer.ensureDeferred(
|
user_info = yield defer.ensureDeferred(
|
||||||
self.auth.get_user_by_access_token(macaroon.serialize())
|
self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
self.assertEqual(user_id, user_info.user_id)
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
|
||||||
|
|
||||||
# TODO: device_id should come from the macaroon, but currently comes
|
# TODO: device_id should come from the macaroon, but currently comes
|
||||||
# from the db.
|
# from the db.
|
||||||
self.assertEqual(user_info["device_id"], "device")
|
self.assertEqual(user_info.device_id, "device")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self):
|
||||||
|
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
user_info = yield defer.ensureDeferred(
|
user_info = yield defer.ensureDeferred(
|
||||||
self.auth.get_user_by_access_token(serialized)
|
self.auth.get_user_by_access_token(serialized)
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
self.assertEqual(user_id, user_info.user_id)
|
||||||
is_guest = user_info["is_guest"]
|
self.assertTrue(user_info.is_guest)
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
|
||||||
self.assertTrue(is_guest)
|
|
||||||
self.store.get_user_by_id.assert_called_with(user_id)
|
self.store.get_user_by_id.assert_called_with(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
|
||||||
if token != tok:
|
if token != tok:
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
return defer.succeed(
|
return defer.succeed(
|
||||||
{
|
TokenLookupResult(
|
||||||
"name": USER_ID,
|
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
|
||||||
"is_guest": False,
|
)
|
||||||
"token_id": 1234,
|
|
||||||
"device_id": "DEVICE",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.get_user_by_access_token = get_user
|
self.store.get_user_by_access_token = get_user
|
||||||
|
|
|
@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
|
|
||||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
None, "example.com", id="foo", rate_limited=True,
|
None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
|
||||||
)
|
)
|
||||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
|
|
||||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
def test_allowed_appservice_via_can_requester_do_action(self):
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
None, "example.com", id="foo", rate_limited=False,
|
None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
|
||||||
)
|
)
|
||||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.service = ApplicationService(
|
self.service = ApplicationService(
|
||||||
id="unique_identifier",
|
id="unique_identifier",
|
||||||
|
sender="@as:test",
|
||||||
url="some_url",
|
url="some_url",
|
||||||
token="some_token",
|
token="some_token",
|
||||||
hostname="matrix.org", # only used by get_groups_for_user
|
hostname="matrix.org", # only used by get_groups_for_user
|
||||||
|
|
|
@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||||
# make sure that our device ID has changed
|
# make sure that our device ID has changed
|
||||||
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
||||||
|
|
||||||
self.assertEqual(user_info["device_id"], retrieved_device_id)
|
self.assertEqual(user_info.device_id, retrieved_device_id)
|
||||||
|
|
||||||
# make sure the device has the display name that was set from the login
|
# make sure the device has the display name that was set from the login
|
||||||
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
||||||
|
|
|
@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
self.info = self.get_success(
|
self.info = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
||||||
)
|
)
|
||||||
self.token_id = self.info["token_id"]
|
self.token_id = self.info.token_id
|
||||||
|
|
||||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerCleanupMixin:
|
||||||
|
def get_logger(self, handler):
|
||||||
|
"""
|
||||||
|
Attach a handler to a logger and add clean-ups to remove revert this.
|
||||||
|
"""
|
||||||
|
# Create a logger and add the handler to it.
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
# Ensure the logger actually logs something.
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# Ensure the logger gets cleaned-up appropriately.
|
||||||
|
self.addCleanup(logger.removeHandler, handler)
|
||||||
|
self.addCleanup(logger.setLevel, logging.NOTSET)
|
||||||
|
|
||||||
|
return logger
|
169
tests/logging/test_remote_handler.py
Normal file
169
tests/logging/test_remote_handler.py
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||||
|
|
||||||
|
from synapse.logging import RemoteHandler
|
||||||
|
|
||||||
|
from tests.logging import LoggerCleanupMixin
|
||||||
|
from tests.server import FakeTransport, get_clock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
def connect_logging_client(reactor, client_id):
|
||||||
|
# This is essentially tests.server.connect_client, but disabling autoflush on
|
||||||
|
# the client transport. This is necessary to avoid an infinite loop due to
|
||||||
|
# sending of data via the logging transport causing additional logs to be
|
||||||
|
# written.
|
||||||
|
factory = reactor.tcpClients.pop(client_id)[2]
|
||||||
|
client = factory.buildProtocol(None)
|
||||||
|
server = AccumulatingProtocol()
|
||||||
|
server.makeConnection(FakeTransport(client, reactor))
|
||||||
|
client.makeConnection(FakeTransport(server, reactor, autoflush=False))
|
||||||
|
|
||||||
|
return client, server
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor, _ = get_clock()
|
||||||
|
|
||||||
|
def test_log_output(self):
|
||||||
|
"""
|
||||||
|
The remote handler delivers logs over TCP.
|
||||||
|
"""
|
||||||
|
handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
logger.info("Hello there, %s!", "wally")
|
||||||
|
|
||||||
|
# Trigger the connection
|
||||||
|
client, server = connect_logging_client(self.reactor, 0)
|
||||||
|
|
||||||
|
# Trigger data being sent
|
||||||
|
client.transport.flush()
|
||||||
|
|
||||||
|
# One log message, with a single trailing newline
|
||||||
|
logs = server.data.decode("utf8").splitlines()
|
||||||
|
self.assertEqual(len(logs), 1)
|
||||||
|
self.assertEqual(server.data.count(b"\n"), 1)
|
||||||
|
|
||||||
|
# Ensure the data passed through properly.
|
||||||
|
self.assertEqual(logs[0], "Hello there, wally!")
|
||||||
|
|
||||||
|
def test_log_backpressure_debug(self):
|
||||||
|
"""
|
||||||
|
When backpressure is hit, DEBUG logs will be shed.
|
||||||
|
"""
|
||||||
|
handler = RemoteHandler(
|
||||||
|
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||||
|
)
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
# Send some debug messages
|
||||||
|
for i in range(0, 3):
|
||||||
|
logger.debug("debug %s" % (i,))
|
||||||
|
|
||||||
|
# Send a bunch of useful messages
|
||||||
|
for i in range(0, 7):
|
||||||
|
logger.info("info %s" % (i,))
|
||||||
|
|
||||||
|
# The last debug message pushes it past the maximum buffer
|
||||||
|
logger.debug("too much debug")
|
||||||
|
|
||||||
|
# Allow the reconnection
|
||||||
|
client, server = connect_logging_client(self.reactor, 0)
|
||||||
|
client.transport.flush()
|
||||||
|
|
||||||
|
# Only the 7 infos made it through, the debugs were elided
|
||||||
|
logs = server.data.splitlines()
|
||||||
|
self.assertEqual(len(logs), 7)
|
||||||
|
self.assertNotIn(b"debug", server.data)
|
||||||
|
|
||||||
|
def test_log_backpressure_info(self):
|
||||||
|
"""
|
||||||
|
When backpressure is hit, DEBUG and INFO logs will be shed.
|
||||||
|
"""
|
||||||
|
handler = RemoteHandler(
|
||||||
|
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||||
|
)
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
# Send some debug messages
|
||||||
|
for i in range(0, 3):
|
||||||
|
logger.debug("debug %s" % (i,))
|
||||||
|
|
||||||
|
# Send a bunch of useful messages
|
||||||
|
for i in range(0, 10):
|
||||||
|
logger.warning("warn %s" % (i,))
|
||||||
|
|
||||||
|
# Send a bunch of info messages
|
||||||
|
for i in range(0, 3):
|
||||||
|
logger.info("info %s" % (i,))
|
||||||
|
|
||||||
|
# The last debug message pushes it past the maximum buffer
|
||||||
|
logger.debug("too much debug")
|
||||||
|
|
||||||
|
# Allow the reconnection
|
||||||
|
client, server = connect_logging_client(self.reactor, 0)
|
||||||
|
client.transport.flush()
|
||||||
|
|
||||||
|
# The 10 warnings made it through, the debugs and infos were elided
|
||||||
|
logs = server.data.splitlines()
|
||||||
|
self.assertEqual(len(logs), 10)
|
||||||
|
self.assertNotIn(b"debug", server.data)
|
||||||
|
self.assertNotIn(b"info", server.data)
|
||||||
|
|
||||||
|
def test_log_backpressure_cut_middle(self):
|
||||||
|
"""
|
||||||
|
When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
|
||||||
|
it will cut the middle messages out.
|
||||||
|
"""
|
||||||
|
handler = RemoteHandler(
|
||||||
|
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||||
|
)
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
# Send a bunch of useful messages
|
||||||
|
for i in range(0, 20):
|
||||||
|
logger.warning("warn %s" % (i,))
|
||||||
|
|
||||||
|
# Allow the reconnection
|
||||||
|
client, server = connect_logging_client(self.reactor, 0)
|
||||||
|
client.transport.flush()
|
||||||
|
|
||||||
|
# The first five and last five warnings made it through, the debugs and
|
||||||
|
# infos were elided
|
||||||
|
logs = server.data.decode("utf8").splitlines()
|
||||||
|
self.assertEqual(
|
||||||
|
["warn %s" % (i,) for i in range(5)]
|
||||||
|
+ ["warn %s" % (i,) for i in range(15, 20)],
|
||||||
|
logs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cancel_connection(self):
|
||||||
|
"""
|
||||||
|
Gracefully handle the connection being cancelled.
|
||||||
|
"""
|
||||||
|
handler = RemoteHandler(
|
||||||
|
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
|
||||||
|
)
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
# Send a message.
|
||||||
|
logger.info("Hello there, %s!", "wally")
|
||||||
|
|
||||||
|
# Do not accept the connection and shutdown. This causes the pending
|
||||||
|
# connection to be cancelled (and should not raise any exceptions).
|
||||||
|
handler.close()
|
|
@ -1,214 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import os.path
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import textwrap
|
|
||||||
|
|
||||||
from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
|
|
||||||
|
|
||||||
from synapse.config.logger import setup_logging
|
|
||||||
from synapse.logging._structured import setup_structured_logging
|
|
||||||
from synapse.logging.context import LoggingContext
|
|
||||||
|
|
||||||
from tests.unittest import DEBUG, HomeserverTestCase
|
|
||||||
|
|
||||||
|
|
||||||
class FakeBeginner:
|
|
||||||
def beginLoggingTo(self, observers, **kwargs):
|
|
||||||
self.observers = observers
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredLoggingTestBase:
|
|
||||||
"""
|
|
||||||
Test base that registers a cleanup handler to reset the stdlib log handler
|
|
||||||
to 'unset'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
|
||||||
def _cleanup():
|
|
||||||
logging.getLogger("synapse").setLevel(logging.NOTSET)
|
|
||||||
|
|
||||||
self.addCleanup(_cleanup)
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
|
|
||||||
"""
|
|
||||||
Tests for Synapse's structured logging support.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_output_to_json_round_trip(self):
|
|
||||||
"""
|
|
||||||
Synapse logs can be outputted to JSON and then read back again.
|
|
||||||
"""
|
|
||||||
temp_dir = self.mktemp()
|
|
||||||
os.mkdir(temp_dir)
|
|
||||||
self.addCleanup(shutil.rmtree, temp_dir)
|
|
||||||
|
|
||||||
json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
|
|
||||||
|
|
||||||
log_config = {
|
|
||||||
"drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
setup_structured_logging(
|
|
||||||
self.hs, self.hs.config, log_config, logBeginner=beginner
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make a logger and send an event
|
|
||||||
logger = Logger(
|
|
||||||
namespace="tests.logging.test_structured", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
logger.info("Hello there, {name}!", name="wally")
|
|
||||||
|
|
||||||
# Read the log file and check it has the event we sent
|
|
||||||
with open(json_log_file, "r") as f:
|
|
||||||
logged_events = list(eventsFromJSONLogFile(f))
|
|
||||||
self.assertEqual(len(logged_events), 1)
|
|
||||||
|
|
||||||
# The event pulled from the file should render fine
|
|
||||||
self.assertEqual(
|
|
||||||
eventAsText(logged_events[0], includeTimestamp=False),
|
|
||||||
"[tests.logging.test_structured#info] Hello there, wally!",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_output_to_text(self):
|
|
||||||
"""
|
|
||||||
Synapse logs can be outputted to text.
|
|
||||||
"""
|
|
||||||
temp_dir = self.mktemp()
|
|
||||||
os.mkdir(temp_dir)
|
|
||||||
self.addCleanup(shutil.rmtree, temp_dir)
|
|
||||||
|
|
||||||
log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
|
|
||||||
|
|
||||||
log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
setup_structured_logging(
|
|
||||||
self.hs, self.hs.config, log_config, logBeginner=beginner
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make a logger and send an event
|
|
||||||
logger = Logger(
|
|
||||||
namespace="tests.logging.test_structured", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
logger.info("Hello there, {name}!", name="wally")
|
|
||||||
|
|
||||||
# Read the log file and check it has the event we sent
|
|
||||||
with open(log_file, "r") as f:
|
|
||||||
logged_events = f.read().strip().split("\n")
|
|
||||||
self.assertEqual(len(logged_events), 1)
|
|
||||||
|
|
||||||
# The event pulled from the file should render fine
|
|
||||||
self.assertTrue(
|
|
||||||
logged_events[0].endswith(
|
|
||||||
" - tests.logging.test_structured - INFO - None - Hello there, wally!"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_collects_logcontext(self):
|
|
||||||
"""
|
|
||||||
Test that log outputs have the attached logging context.
|
|
||||||
"""
|
|
||||||
log_config = {"drains": {}}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
publisher = setup_structured_logging(
|
|
||||||
self.hs, self.hs.config, log_config, logBeginner=beginner
|
|
||||||
)
|
|
||||||
|
|
||||||
logs = []
|
|
||||||
|
|
||||||
publisher.addObserver(logs.append)
|
|
||||||
|
|
||||||
# Make a logger and send an event
|
|
||||||
logger = Logger(
|
|
||||||
namespace="tests.logging.test_structured", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
with LoggingContext("testcontext", request="somereq"):
|
|
||||||
logger.info("Hello there, {name}!", name="steve")
|
|
||||||
|
|
||||||
self.assertEqual(len(logs), 1)
|
|
||||||
self.assertEqual(logs[0]["request"], "somereq")
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredLoggingConfigurationFileTestCase(
|
|
||||||
StructuredLoggingTestBase, HomeserverTestCase
|
|
||||||
):
|
|
||||||
def make_homeserver(self, reactor, clock):
|
|
||||||
|
|
||||||
tempdir = self.mktemp()
|
|
||||||
os.mkdir(tempdir)
|
|
||||||
log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
|
|
||||||
self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
|
|
||||||
|
|
||||||
config = self.default_config()
|
|
||||||
config["log_config"] = log_config_file
|
|
||||||
|
|
||||||
with open(log_config_file, "w") as f:
|
|
||||||
f.write(
|
|
||||||
textwrap.dedent(
|
|
||||||
"""\
|
|
||||||
structured: true
|
|
||||||
|
|
||||||
drains:
|
|
||||||
file:
|
|
||||||
type: file_json
|
|
||||||
location: %s
|
|
||||||
"""
|
|
||||||
% (self.homeserver_log,)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.addCleanup(self._sys_cleanup)
|
|
||||||
|
|
||||||
return self.setup_test_homeserver(config=config)
|
|
||||||
|
|
||||||
def _sys_cleanup(self):
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
# Do not remove! We need the logging system to be set other than WARNING.
|
|
||||||
@DEBUG
|
|
||||||
def test_log_output(self):
|
|
||||||
"""
|
|
||||||
When a structured logging config is given, Synapse will use it.
|
|
||||||
"""
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
|
|
||||||
|
|
||||||
# Make a logger and send an event
|
|
||||||
logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
|
|
||||||
|
|
||||||
with LoggingContext("testcontext", request="somereq"):
|
|
||||||
logger.info("Hello there, {name}!", name="steve")
|
|
||||||
|
|
||||||
with open(self.homeserver_log, "r") as f:
|
|
||||||
logged_events = [
|
|
||||||
eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
|
|
||||||
]
|
|
||||||
|
|
||||||
logs = "\n".join(logged_events)
|
|
||||||
self.assertTrue("***** STARTING SERVER *****" in logs)
|
|
||||||
self.assertTrue("Hello there, steve!" in logs)
|
|
|
@ -14,57 +14,33 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from collections import Counter
|
import logging
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
from twisted.logger import Logger
|
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
|
||||||
|
|
||||||
from synapse.logging._structured import setup_structured_logging
|
from tests.logging import LoggerCleanupMixin
|
||||||
|
from tests.unittest import TestCase
|
||||||
from tests.server import connect_client
|
|
||||||
from tests.unittest import HomeserverTestCase
|
|
||||||
|
|
||||||
from .test_structured import FakeBeginner, StructuredLoggingTestBase
|
|
||||||
|
|
||||||
|
|
||||||
class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
|
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||||
def test_log_output(self):
|
def test_terse_json_output(self):
|
||||||
"""
|
"""
|
||||||
The Terse JSON outputter delivers simplified structured logs over TCP.
|
The Terse JSON formatter converts log messages to JSON.
|
||||||
"""
|
"""
|
||||||
log_config = {
|
output = StringIO()
|
||||||
"drains": {
|
|
||||||
"tersejson": {
|
|
||||||
"type": "network_json_terse",
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": 8000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
handler = logging.StreamHandler(output)
|
||||||
beginner = FakeBeginner()
|
handler.setFormatter(TerseJsonFormatter())
|
||||||
setup_structured_logging(
|
logger = self.get_logger(handler)
|
||||||
self.hs, self.hs.config, log_config, logBeginner=beginner
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = Logger(
|
logger.info("Hello there, %s!", "wally")
|
||||||
namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
logger.info("Hello there, {name}!", name="wally")
|
|
||||||
|
|
||||||
# Trigger the connection
|
# One log message, with a single trailing newline.
|
||||||
self.pump()
|
data = output.getvalue()
|
||||||
|
logs = data.splitlines()
|
||||||
_, server = connect_client(self.reactor, 0)
|
|
||||||
|
|
||||||
# Trigger data being sent
|
|
||||||
self.pump()
|
|
||||||
|
|
||||||
# One log message, with a single trailing newline
|
|
||||||
logs = server.data.decode("utf8").splitlines()
|
|
||||||
self.assertEqual(len(logs), 1)
|
self.assertEqual(len(logs), 1)
|
||||||
self.assertEqual(server.data.count(b"\n"), 1)
|
self.assertEqual(data.count("\n"), 1)
|
||||||
|
|
||||||
log = json.loads(logs[0])
|
log = json.loads(logs[0])
|
||||||
|
|
||||||
# The terse logger should give us these keys.
|
# The terse logger should give us these keys.
|
||||||
|
@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
|
||||||
"log",
|
"log",
|
||||||
"time",
|
"time",
|
||||||
"level",
|
"level",
|
||||||
"log_namespace",
|
"namespace",
|
||||||
"request",
|
]
|
||||||
"scope",
|
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||||
"server_name",
|
self.assertEqual(log["log"], "Hello there, wally!")
|
||||||
"name",
|
|
||||||
|
def test_extra_data(self):
|
||||||
|
"""
|
||||||
|
Additional information can be included in the structured logging.
|
||||||
|
"""
|
||||||
|
output = StringIO()
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(output)
|
||||||
|
handler.setFormatter(TerseJsonFormatter())
|
||||||
|
logger = self.get_logger(handler)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
# One log message, with a single trailing newline.
|
||||||
|
data = output.getvalue()
|
||||||
|
logs = data.splitlines()
|
||||||
|
self.assertEqual(len(logs), 1)
|
||||||
|
self.assertEqual(data.count("\n"), 1)
|
||||||
|
log = json.loads(logs[0])
|
||||||
|
|
||||||
|
# The terse logger should give us these keys.
|
||||||
|
expected_log_keys = [
|
||||||
|
"log",
|
||||||
|
"time",
|
||||||
|
"level",
|
||||||
|
"namespace",
|
||||||
|
# The additional keys given via extra.
|
||||||
|
"foo",
|
||||||
|
"int",
|
||||||
|
"bool",
|
||||||
]
|
]
|
||||||
self.assertCountEqual(log.keys(), expected_log_keys)
|
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||||
|
|
||||||
# It contains the data we expect.
|
# Check the values of the extra fields.
|
||||||
self.assertEqual(log["name"], "wally")
|
self.assertEqual(log["foo"], "bar")
|
||||||
|
self.assertEqual(log["int"], 3)
|
||||||
|
self.assertIs(log["bool"], True)
|
||||||
|
|
||||||
def test_log_backpressure_debug(self):
|
def test_json_output(self):
|
||||||
"""
|
"""
|
||||||
When backpressure is hit, DEBUG logs will be shed.
|
The Terse JSON formatter converts log messages to JSON.
|
||||||
"""
|
"""
|
||||||
log_config = {
|
output = StringIO()
|
||||||
"loggers": {"synapse": {"level": "DEBUG"}},
|
|
||||||
"drains": {
|
|
||||||
"tersejson": {
|
|
||||||
"type": "network_json_terse",
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": 8000,
|
|
||||||
"maximum_buffer": 10,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
handler = logging.StreamHandler(output)
|
||||||
beginner = FakeBeginner()
|
handler.setFormatter(JsonFormatter())
|
||||||
setup_structured_logging(
|
logger = self.get_logger(handler)
|
||||||
self.hs,
|
|
||||||
self.hs.config,
|
|
||||||
log_config,
|
|
||||||
logBeginner=beginner,
|
|
||||||
redirect_stdlib_logging=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = Logger(
|
logger.info("Hello there, %s!", "wally")
|
||||||
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send some debug messages
|
# One log message, with a single trailing newline.
|
||||||
for i in range(0, 3):
|
data = output.getvalue()
|
||||||
logger.debug("debug %s" % (i,))
|
logs = data.splitlines()
|
||||||
|
self.assertEqual(len(logs), 1)
|
||||||
|
self.assertEqual(data.count("\n"), 1)
|
||||||
|
log = json.loads(logs[0])
|
||||||
|
|
||||||
# Send a bunch of useful messages
|
# The terse logger should give us these keys.
|
||||||
for i in range(0, 7):
|
expected_log_keys = [
|
||||||
logger.info("test message %s" % (i,))
|
"log",
|
||||||
|
"level",
|
||||||
# The last debug message pushes it past the maximum buffer
|
"namespace",
|
||||||
logger.debug("too much debug")
|
]
|
||||||
|
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||||
# Allow the reconnection
|
self.assertEqual(log["log"], "Hello there, wally!")
|
||||||
_, server = connect_client(self.reactor, 0)
|
|
||||||
self.pump()
|
|
||||||
|
|
||||||
# Only the 7 infos made it through, the debugs were elided
|
|
||||||
logs = server.data.splitlines()
|
|
||||||
self.assertEqual(len(logs), 7)
|
|
||||||
|
|
||||||
def test_log_backpressure_info(self):
|
|
||||||
"""
|
|
||||||
When backpressure is hit, DEBUG and INFO logs will be shed.
|
|
||||||
"""
|
|
||||||
log_config = {
|
|
||||||
"loggers": {"synapse": {"level": "DEBUG"}},
|
|
||||||
"drains": {
|
|
||||||
"tersejson": {
|
|
||||||
"type": "network_json_terse",
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": 8000,
|
|
||||||
"maximum_buffer": 10,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
setup_structured_logging(
|
|
||||||
self.hs,
|
|
||||||
self.hs.config,
|
|
||||||
log_config,
|
|
||||||
logBeginner=beginner,
|
|
||||||
redirect_stdlib_logging=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = Logger(
|
|
||||||
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send some debug messages
|
|
||||||
for i in range(0, 3):
|
|
||||||
logger.debug("debug %s" % (i,))
|
|
||||||
|
|
||||||
# Send a bunch of useful messages
|
|
||||||
for i in range(0, 10):
|
|
||||||
logger.warn("test warn %s" % (i,))
|
|
||||||
|
|
||||||
# Send a bunch of info messages
|
|
||||||
for i in range(0, 3):
|
|
||||||
logger.info("test message %s" % (i,))
|
|
||||||
|
|
||||||
# The last debug message pushes it past the maximum buffer
|
|
||||||
logger.debug("too much debug")
|
|
||||||
|
|
||||||
# Allow the reconnection
|
|
||||||
client, server = connect_client(self.reactor, 0)
|
|
||||||
self.pump()
|
|
||||||
|
|
||||||
# The 10 warnings made it through, the debugs and infos were elided
|
|
||||||
logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
|
|
||||||
self.assertEqual(len(logs), 10)
|
|
||||||
|
|
||||||
self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
|
|
||||||
|
|
||||||
def test_log_backpressure_cut_middle(self):
|
|
||||||
"""
|
|
||||||
When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
|
|
||||||
it will cut the middle messages out.
|
|
||||||
"""
|
|
||||||
log_config = {
|
|
||||||
"loggers": {"synapse": {"level": "DEBUG"}},
|
|
||||||
"drains": {
|
|
||||||
"tersejson": {
|
|
||||||
"type": "network_json_terse",
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": 8000,
|
|
||||||
"maximum_buffer": 10,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Begin the logger with our config
|
|
||||||
beginner = FakeBeginner()
|
|
||||||
setup_structured_logging(
|
|
||||||
self.hs,
|
|
||||||
self.hs.config,
|
|
||||||
log_config,
|
|
||||||
logBeginner=beginner,
|
|
||||||
redirect_stdlib_logging=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = Logger(
|
|
||||||
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send a bunch of useful messages
|
|
||||||
for i in range(0, 20):
|
|
||||||
logger.warn("test warn", num=i)
|
|
||||||
|
|
||||||
# Allow the reconnection
|
|
||||||
client, server = connect_client(self.reactor, 0)
|
|
||||||
self.pump()
|
|
||||||
|
|
||||||
# The first five and last five warnings made it through, the debugs and
|
|
||||||
# infos were elided
|
|
||||||
logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
|
|
||||||
self.assertEqual(len(logs), 10)
|
|
||||||
self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
|
|
||||||
self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.pusher = self.get_success(
|
self.pusher = self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|
|
@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|
|
@ -16,7 +16,6 @@ import logging
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import hiredis
|
|
||||||
|
|
||||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol
|
||||||
|
@ -39,12 +38,22 @@ from synapse.util import Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeTransport, render
|
from tests.server import FakeTransport, render
|
||||||
|
|
||||||
|
try:
|
||||||
|
import hiredis
|
||||||
|
except ImportError:
|
||||||
|
hiredis = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
"""Base class for tests of the replication streams"""
|
"""Base class for tests of the replication streams"""
|
||||||
|
|
||||||
|
# hiredis is an optional dependency so we don't want to require it for running
|
||||||
|
# the tests.
|
||||||
|
if not hiredis:
|
||||||
|
skip = "Requires hiredis"
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
streams.register_servlets,
|
streams.register_servlets,
|
||||||
]
|
]
|
||||||
|
@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
homeserver_to_use=GenericWorkerServer,
|
homeserver_to_use=GenericWorkerServer,
|
||||||
config=config,
|
config=config,
|
||||||
reactor=self.reactor,
|
reactor=self.reactor,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the instance is in the `instance_map` config then workers may try
|
# If the instance is in the `instance_map` config then workers may try
|
||||||
|
|
|
@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
sender=sender,
|
sender=sender,
|
||||||
type="test_event",
|
type="test_event",
|
||||||
content={"body": body},
|
content={"body": body},
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
277
tests/replication/test_multi_media_repo.py
Normal file
277
tests/replication/test_multi_media_repo.py
Normal file
|
@ -0,0 +1,277 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from binascii import unhexlify
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
|
from twisted.web.http import HTTPChannel
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client.v1 import login
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
||||||
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
from tests.server import FakeChannel, FakeTransport
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
test_server_connection_factory = None
|
||||||
|
|
||||||
|
|
||||||
|
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
"""Checks running multiple media repos work correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets_for_client_rest_resource,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.user_id = self.register_user("user", "pass")
|
||||||
|
self.access_token = self.login("user", "pass")
|
||||||
|
|
||||||
|
self.reactor.lookups["example.com"] = "127.0.0.2"
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
conf = super().default_config()
|
||||||
|
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
||||||
|
return conf
|
||||||
|
|
||||||
|
def _get_media_req(
|
||||||
|
self, hs: HomeServer, target: str, media_id: str
|
||||||
|
) -> Tuple[FakeChannel, Request]:
|
||||||
|
"""Request some remote media from the given HS by calling the download
|
||||||
|
API.
|
||||||
|
|
||||||
|
This then triggers an outbound request from the HS to the target.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The channel for the *client* request and the *outbound* request for
|
||||||
|
the media which the caller should respond to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/{}/{}".format(target, media_id),
|
||||||
|
shorthand=False,
|
||||||
|
access_token=self.access_token,
|
||||||
|
)
|
||||||
|
request.render(hs.get_media_repository_resource().children[b"download"])
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertGreaterEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
||||||
|
|
||||||
|
# build the test server
|
||||||
|
server_tls_protocol = _build_test_server(get_connection_factory())
|
||||||
|
|
||||||
|
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||||
|
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||||
|
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||||
|
# a FakeTransport.
|
||||||
|
#
|
||||||
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
|
# stubbing that out here.
|
||||||
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
client_protocol.makeConnection(
|
||||||
|
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
# tell the server tls protocol to send its stuff back to the client, too
|
||||||
|
server_tls_protocol.makeConnection(
|
||||||
|
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
|
http_server = server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(
|
||||||
|
request.path,
|
||||||
|
"/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
|
||||||
|
)
|
||||||
|
|
||||||
|
return channel, request
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
"""Test basic fetching of remote media from a single worker.
|
||||||
|
"""
|
||||||
|
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
||||||
|
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
|
||||||
|
|
||||||
|
request.setResponseCode(200)
|
||||||
|
request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
|
||||||
|
request.write(b"Hello!")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(channel.result["body"], b"Hello!")
|
||||||
|
|
||||||
|
def test_download_simple_file_race(self):
|
||||||
|
"""Test that fetching remote media from two different processes at the
|
||||||
|
same time works.
|
||||||
|
"""
|
||||||
|
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
||||||
|
start_count = self._count_remote_media()
|
||||||
|
|
||||||
|
# Make two requests without responding to the outbound media requests.
|
||||||
|
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
|
||||||
|
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
|
||||||
|
|
||||||
|
# Respond to the first outbound media request and check that the client
|
||||||
|
# request is successful
|
||||||
|
request1.setResponseCode(200)
|
||||||
|
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
|
||||||
|
request1.write(b"Hello!")
|
||||||
|
request1.finish()
|
||||||
|
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
self.assertEqual(channel1.code, 200, channel1.result["body"])
|
||||||
|
self.assertEqual(channel1.result["body"], b"Hello!")
|
||||||
|
|
||||||
|
# Now respond to the second with the same content.
|
||||||
|
request2.setResponseCode(200)
|
||||||
|
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
|
||||||
|
request2.write(b"Hello!")
|
||||||
|
request2.finish()
|
||||||
|
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||||
|
self.assertEqual(channel2.result["body"], b"Hello!")
|
||||||
|
|
||||||
|
# We expect only one new file to have been persisted.
|
||||||
|
self.assertEqual(start_count + 1, self._count_remote_media())
|
||||||
|
|
||||||
|
def test_download_image_race(self):
|
||||||
|
"""Test that fetching remote *images* from two different processes at
|
||||||
|
the same time works.
|
||||||
|
|
||||||
|
This checks that races generating thumbnails are handled correctly.
|
||||||
|
"""
|
||||||
|
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
||||||
|
start_count = self._count_remote_thumbnails()
|
||||||
|
|
||||||
|
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
|
||||||
|
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
|
||||||
|
|
||||||
|
png_data = unhexlify(
|
||||||
|
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||||
|
b"0000001f15c4890000000a49444154789c63000100000500010d"
|
||||||
|
b"0a2db40000000049454e44ae426082"
|
||||||
|
)
|
||||||
|
|
||||||
|
request1.setResponseCode(200)
|
||||||
|
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
|
||||||
|
request1.write(png_data)
|
||||||
|
request1.finish()
|
||||||
|
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
self.assertEqual(channel1.code, 200, channel1.result["body"])
|
||||||
|
self.assertEqual(channel1.result["body"], png_data)
|
||||||
|
|
||||||
|
request2.setResponseCode(200)
|
||||||
|
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
|
||||||
|
request2.write(png_data)
|
||||||
|
request2.finish()
|
||||||
|
|
||||||
|
self.pump(0.1)
|
||||||
|
|
||||||
|
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||||
|
self.assertEqual(channel2.result["body"], png_data)
|
||||||
|
|
||||||
|
# We expect only three new thumbnails to have been persisted.
|
||||||
|
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
|
||||||
|
|
||||||
|
def _count_remote_media(self) -> int:
|
||||||
|
"""Count the number of files in our remote media directory.
|
||||||
|
"""
|
||||||
|
path = os.path.join(
|
||||||
|
self.hs.get_media_repository().primary_base_path, "remote_content"
|
||||||
|
)
|
||||||
|
return sum(len(files) for _, _, files in os.walk(path))
|
||||||
|
|
||||||
|
def _count_remote_thumbnails(self) -> int:
|
||||||
|
"""Count the number of files in our remote thumbnails directory.
|
||||||
|
"""
|
||||||
|
path = os.path.join(
|
||||||
|
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
|
||||||
|
)
|
||||||
|
return sum(len(files) for _, _, files in os.walk(path))
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection_factory():
|
||||||
|
# this needs to happen once, but not until we are ready to run the first test
|
||||||
|
global test_server_connection_factory
|
||||||
|
if test_server_connection_factory is None:
|
||||||
|
test_server_connection_factory = TestServerTLSConnectionFactory(
|
||||||
|
sanlist=[b"DNS:example.com"]
|
||||||
|
)
|
||||||
|
return test_server_connection_factory
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_server(connection_creator):
|
||||||
|
"""Construct a test server
|
||||||
|
|
||||||
|
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_creator (IOpenSSLServerConnectionCreator): thing to build
|
||||||
|
SSL connections
|
||||||
|
sanlist (list[bytes]): list of the SAN entries for the cert returned
|
||||||
|
by the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TLSMemoryBIOProtocol
|
||||||
|
"""
|
||||||
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
server_tls_factory = TLSMemoryBIOFactory(
|
||||||
|
connection_creator, isClient=False, wrappedFactory=server_factory
|
||||||
|
)
|
||||||
|
|
||||||
|
return server_tls_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request(request):
|
||||||
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
|
logger.info("Completed request %s", request)
|
|
@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
user_dict = self.get_success(
|
user_dict = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_dict["token_id"]
|
token_id = user_dict.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|
|
@ -1118,6 +1118,130 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
|
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
|
||||||
|
|
||||||
|
|
||||||
|
class PushersRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
|
||||||
|
self.other_user = self.register_user("user", "pass")
|
||||||
|
self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
|
||||||
|
self.other_user
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_auth(self):
|
||||||
|
"""
|
||||||
|
Try to list pushers of an user without authentication.
|
||||||
|
"""
|
||||||
|
request, channel = self.make_request("GET", self.url, b"{}")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_requester_is_no_admin(self):
|
||||||
|
"""
|
||||||
|
If the user is not a server admin, an error is returned.
|
||||||
|
"""
|
||||||
|
other_user_token = self.login("user", "pass")
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", self.url, access_token=other_user_token,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_user_does_not_exist(self):
|
||||||
|
"""
|
||||||
|
Tests that a lookup for a user that does not exist returns a 404
|
||||||
|
"""
|
||||||
|
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", url, access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(404, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_user_is_not_local(self):
|
||||||
|
"""
|
||||||
|
Tests that a lookup for a user that is not a local returns a 400
|
||||||
|
"""
|
||||||
|
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", url, access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(400, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual("Can only lookup local users", channel.json_body["error"])
|
||||||
|
|
||||||
|
def test_get_pushers(self):
|
||||||
|
"""
|
||||||
|
Tests that a normal lookup for pushers is successfully
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get pushers
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", self.url, access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual(0, channel.json_body["total"])
|
||||||
|
|
||||||
|
# Register the pusher
|
||||||
|
other_user_token = self.login("user", "pass")
|
||||||
|
user_tuple = self.get_success(
|
||||||
|
self.store.get_user_by_access_token(other_user_token)
|
||||||
|
)
|
||||||
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
user_id=self.other_user,
|
||||||
|
access_token=token_id,
|
||||||
|
kind="http",
|
||||||
|
app_id="m.http",
|
||||||
|
app_display_name="HTTP Push Notifications",
|
||||||
|
device_display_name="pushy push",
|
||||||
|
pushkey="a@example.com",
|
||||||
|
lang=None,
|
||||||
|
data={"url": "example.com"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get pushers
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", self.url, access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual(1, channel.json_body["total"])
|
||||||
|
|
||||||
|
for p in channel.json_body["pushers"]:
|
||||||
|
self.assertIn("pushkey", p)
|
||||||
|
self.assertIn("kind", p)
|
||||||
|
self.assertIn("app_id", p)
|
||||||
|
self.assertIn("app_display_name", p)
|
||||||
|
self.assertIn("device_display_name", p)
|
||||||
|
self.assertIn("profile_tag", p)
|
||||||
|
self.assertIn("lang", p)
|
||||||
|
self.assertIn("url", p["data"])
|
||||||
|
|
||||||
|
|
||||||
class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
|
@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.hs.config.server_name,
|
self.hs.config.server_name,
|
||||||
id="1234",
|
id="1234",
|
||||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||||
|
sender="@as:test",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_datastore().services_cache.append(appservice)
|
self.hs.get_datastore().services_cache.append(appservice)
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue