Merge branch 'release-v0.17.0' of github.com:matrix-org/synapse

Commit d330d45e2d by Erik Johnston, 2016-08-08 13:38:13 +01:00
136 changed files with 5946 additions and 1657 deletions


@@ -1,3 +1,125 @@
Changes in synapse v0.17.0 (2016-08-08)
=======================================

This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.

This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.

Changes:

* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)

Bug fixes:

* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)


Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================

Changes:

* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)

Bug fixes:

* Fix event persistence when event has already been partially persisted
  (PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)


Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================

Changes:

* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)

Bug fixes:

* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)


Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================

(This release did not include the changes advertised and was identical to RC1)


Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================

This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.

Features:

* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
  (PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)

Changes:

* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
  KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
  (PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)

Bug fixes:

* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
  (PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
  rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
  (PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)


Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================


@@ -14,6 +14,7 @@ recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
+recursive-include synapse *.pyi
recursive-include tests *.py
recursive-include synapse/static *.css
@@ -23,5 +24,7 @@ recursive-include synapse/static *.js
exclude jenkins.sh
exclude jenkins*.sh
+exclude jenkins*
+recursive-exclude jenkins *.sh
prune demo/etc


@@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:

1) Use the machine's own hostname as available on public DNS in the form of
-   its A or AAAA records. This is easier to set up initially, perhaps for
+   its A records. This is easier to set up initially, perhaps for
   testing, but lacks the flexibility of SRV.

2) Set up a SRV record for your domain name. This requires you create a SRV


@@ -27,7 +27,7 @@ running:
    # Pull the latest version of the master branch.
    git pull

    # Update the versions of synapse's python dependencies.
-    python synapse/python_dependencies.py | xargs -n1 pip install
+    python synapse/python_dependencies.py | xargs -n1 pip install --upgrade

Upgrading to v0.15.0

docs/admin_api/README.rst (new file, 12 lines)

@@ -0,0 +1,12 @@
Admin APIs
==========

This directory includes documentation for the various synapse specific admin
APIs available.

Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:

``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``

Restarting may be required for the changes to register.
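For example (a minimal sketch, not part of this commit), on a SQLite-backed
homeserver the same update could be applied with a few lines of Python; the
database path and user ID below are placeholders, and Postgres installs would
use psql instead::

    import sqlite3

    # Set the admin flag for one user; "homeserver.db" and "@foo:bar.com"
    # are placeholders for the real database path and user ID.
    conn = sqlite3.connect("homeserver.db")
    conn.execute("UPDATE users SET admin = 1 WHERE name = ?", ("@foo:bar.com",))
    conn.commit()
    conn.close()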


@@ -0,0 +1,15 @@
Purge History API
=================

The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.

Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.

The API is simply:

``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``

including an ``access_token`` of a server admin.
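As a sketch (not part of this commit), the endpoint could be driven from
Python with ``requests``; the homeserver URL, token, room and event IDs are
placeholders, and the IDs must be URL-encoded::

    import urllib  # Python 2; on Python 3 use urllib.parse.quote

    import requests

    base = "https://matrix.example.com/_matrix/client/r0/admin/purge_history"
    room_id = urllib.quote("!room:example.com")
    event_id = urllib.quote("$event0:example.com")

    resp = requests.post(
        "%s/%s/%s" % (base, room_id, event_id),
        params={"access_token": "ADMIN_TOKEN"},
    )
    print(resp.status_code)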


@@ -0,0 +1,19 @@
Purge Remote Media API
======================

The purge remote media API allows server admins to purge old cached remote
media.

The API is::

    POST /_matrix/client/r0/admin/purge_media_cache

    {
        "before_ts": <unix_timestamp_in_ms>
    }

Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.

If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
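Likewise, a sketch of calling it from Python with ``requests`` (the server
URL, token and timestamp are placeholders)::

    import requests

    resp = requests.post(
        "https://matrix.example.com/_matrix/client/r0/admin/purge_media_cache",
        params={"access_token": "ADMIN_TOKEN"},
        json={"before_ts": 1470000000000},  # unix timestamp in ms
    )
    print(resp.status_code)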


@@ -43,7 +43,10 @@ Basically, PEP8
  together, or want to deliberately extend or preserve vertical/horizontal
  space)

-Comments should follow the google code style. This is so that we can generate
-documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
+Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
+This is so that we can generate documentation with
+`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
+`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
+in the sphinx documentation.

Code should pass pep8 --max-line-length=100 without any warnings.


@@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
server through the use of a secret shared between the Home Server and the
TURN server.

-This document described how to install coturn
-(https://code.google.com/p/coturn/) which also supports the TURN REST API,
+This document describes how to install coturn
+(https://github.com/coturn/coturn) which also supports the TURN REST API,
and integrate it with synapse.

coturn Setup
============

+You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.

1. Check out coturn::

-    svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
+    git clone https://github.com/coturn/coturn.git coturn
    cd coturn

2. Configure it::

    ./configure

-   You may need to install libevent2: if so, you should do so
+   You may need to install ``libevent2``: if so, you should do so
   in the way recommended by your operating system.
   You can ignore warnings about lack of database support: a
   database is unnecessary for this purpose.

3. Build and install it::

    make
    make install

-4. Make a config file in /etc/turnserver.conf. You can customise
-   a config file from turnserver.conf.default. The relevant
+4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
   lines, with example values, are::

    lt-cred-mech

@@ -41,7 +45,7 @@ coturn Setup
    static-auth-secret=[your secret key here]
    realm=turn.myserver.org

-See turnserver.conf.default for explanations of the options.
+See turnserver.conf for explanations of the options.
One way to generate the static-auth-secret is with pwgen::

    pwgen -s 64 1

@@ -54,6 +58,7 @@ coturn Setup
import your private key and certificate.

7. Start the turn server::

    bin/turnserver -o


@@ -4,83 +4,19 @@ set -eux

: ${WORKSPACE:="$(pwd)"}

+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1

-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-rm .coverage* || echo "No coverage files to remove"
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install psycopg2
-$TOX_BIN/pip install lxml
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-if [[ ! -e .dendron-base ]]; then
-  git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
-else
-  (cd .dendron-base; git fetch -p)
-fi
-rm -rf dendron
-git clone .dendron-base dendron --shared
-cd dendron
-: ${GOPATH:=${WORKSPACE}/.gopath}
-if [[ "${GOPATH}" != *:* ]]; then
-  mkdir -p "${GOPATH}"
-  export PATH="${GOPATH}/bin:${PATH}"
-fi
-export GOPATH
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-go get github.com/constabulary/gb/...
-gb generate
-gb build
-cd ..
-if [[ ! -e .sytest-base ]]; then
-  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
-  (cd .sytest-base; git fetch -p)
-fi
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-: ${PORT_BASE:=8000}
-./jenkins/prep_sytest_for_postgres.sh
-mkdir -p var
-echo >&2 "Running sytest with PostgreSQL";
-./jenkins/install_and_run.sh --python $TOX_BIN/python \
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
+./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
+./dendron/jenkins/build_dendron.sh
+./sytest/jenkins/prep_sytest_for_postgres.sh

+./sytest/jenkins/install_and_run.sh \
    --synapse-directory $WORKSPACE \
    --dendron $WORKSPACE/dendron/bin/dendron \
    --pusher \
    --synchrotron \
-    --port-base $PORT_BASE
-cd ..
+    --federation-reader \


@@ -4,60 +4,14 @@ set -eux

: ${WORKSPACE:="$(pwd)"}

+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1

-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-rm .coverage* || echo "No coverage files to remove"
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install psycopg2
-$TOX_BIN/pip install lxml
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-if [[ ! -e .sytest-base ]]; then
-  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
-  (cd .sytest-base; git fetch -p)
-fi
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-: ${PORT_BASE:=8000}
-./jenkins/prep_sytest_for_postgres.sh
-echo >&2 "Running sytest with PostgreSQL";
-./jenkins/install_and_run.sh --coverage \
-    --python $TOX_BIN/python \
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git

+./sytest/jenkins/prep_sytest_for_postgres.sh

+./sytest/jenkins/install_and_run.sh \
    --synapse-directory $WORKSPACE \
-    --port-base $PORT_BASE
-cd ..
-cp sytest/.coverage.* .
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml


@@ -4,54 +4,12 @@ set -eux

: ${WORKSPACE:="$(pwd)"}

+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1

-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-rm .coverage* || echo "No coverage files to remove"
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install lxml
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-if [[ ! -e .sytest-base ]]; then
-  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
-  (cd .sytest-base; git fetch -p)
-fi
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-: ${PORT_BASE:=8500}
-./jenkins/install_and_run.sh --coverage \
-    --python $TOX_BIN/python \
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git

+./sytest/jenkins/install_and_run.sh \
    --synapse-directory $WORKSPACE \
-    --port-base $PORT_BASE
-cd ..
-cp sytest/.coverage.* .
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml


@@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"

+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install

tox -e py27

jenkins/clone.sh (new executable file, 44 lines)

@@ -0,0 +1,44 @@
#! /bin/bash

# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git

set -eux

NAME=$1
PROJECT=$2
BASE=".$NAME-base"

# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
  # Create a local mirror of the source repository.
  # This saves us from having to download the entire repository
  # when this script is next run.
  git clone "$PROJECT" "$BASE" --mirror
else
  # Fetch any updates from the source repository.
  (cd "$BASE"; git fetch -p)
fi

# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared

# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
  echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
  git checkout "origin/develop"
)

jenkins/prepare_synapse.sh (new executable file, 19 lines)

@@ -0,0 +1,19 @@
#! /bin/bash

cd "`dirname $0`/.."

TOX_DIR=$WORKSPACE/.tox

mkdir -p $TOX_DIR

if ! [ $TOX_DIR -ef .tox ]; then
  ln -s "$TOX_DIR" .tox
fi

# set up the virtualenv
tox -e py27 --notest -v

TOX_BIN=$TOX_DIR/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2


@@ -36,7 +36,7 @@
    <div class="debug">
        Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
        an event was received at {{ reason.received_at|format_ts("%c") }}
-        which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
+        which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
        {% if reason.last_sent_ts %}
            and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
            which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.


@@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
    authorization_headers = []

    for key, sig in signed_json["signatures"][origin_name].items():
-        authorization_headers.append(bytes(
-            "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
-                origin_name, key, sig,
-            )
-        ))
+        header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+            origin_name, key, sig,
+        )
+        authorization_headers.append(bytes(header))
+        sys.stderr.write(header)
+        sys.stderr.write("\n")

    result = requests.get(
        lookup(destination, path),
        headers={"Authorization": authorization_headers[0]},
        verify=False,
    )
+    sys.stderr.write("Status Code: %d\n" % (result.status_code,))
    return result.json()

@@ -141,6 +143,7 @@ def main():
    )

    json.dump(result, sys.stdout)
+    print ""


if __name__ == "__main__":
    main()


@@ -1,10 +1,16 @@
#!/usr/bin/env python

import argparse
+import sys
import bcrypt
import getpass
+import yaml

bcrypt_rounds=12
+password_pepper = ""


def prompt_for_pass():
    password = getpass.getpass("Password: ")

@@ -28,12 +34,22 @@ if __name__ == "__main__":
        default=None,
        help="New password for user. Will prompt if omitted.",
    )
+    parser.add_argument(
+        "-c", "--config",
+        type=argparse.FileType('r'),
+        help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
+    )

    args = parser.parse_args()

+    if "config" in args and args.config:
+        config = yaml.safe_load(args.config)
+        bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+        password_config = config.get("password_config", {})
+        password_pepper = password_config.get("pepper", password_pepper)

    password = args.password
    if not password:
        password = prompt_for_pass()

-    print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+    print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
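For reference, a sketch (not from this commit) of the config keys the new
``--config`` option reads; the values here are placeholders, not defaults::

    import yaml

    # The keys the script looks up, per the diff above.
    config = yaml.safe_load("""
    bcrypt_rounds: 12
    password_config:
       pepper: "EXAMPLE_PEPPER"
    """)

    bcrypt_rounds = config.get("bcrypt_rounds", 12)
    pepper = config.get("password_config", {}).get("pepper", "")
    print("%d %s" % (bcrypt_rounds, pepper))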


@@ -25,18 +25,26 @@ import urllib2
import yaml


-def request_registration(user, password, server_location, shared_secret):
+def request_registration(user, password, server_location, shared_secret, admin=False):
    mac = hmac.new(
        key=shared_secret,
-        msg=user,
        digestmod=hashlib.sha1,
-    ).hexdigest()
+    )

+    mac.update(user)
+    mac.update("\x00")
+    mac.update(password)
+    mac.update("\x00")
+    mac.update("admin" if admin else "notadmin")

+    mac = mac.hexdigest()

    data = {
        "user": user,
        "password": password,
        "mac": mac,
        "type": "org.matrix.login.shared_secret",
+        "admin": admin,
    }

    server_location = server_location.rstrip("/")

@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
        sys.exit(1)


-def register_new_user(user, password, server_location, shared_secret):
+def register_new_user(user, password, server_location, shared_secret, admin):
    if not user:
        try:
            default_user = getpass.getuser()

@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
        print "Passwords do not match"
        sys.exit(1)

-    request_registration(user, password, server_location, shared_secret)
+    if not admin:
+        admin = raw_input("Make admin [no]: ")
+        if admin in ("y", "yes", "true"):
+            admin = True
+        else:
+            admin = False

+    request_registration(user, password, server_location, shared_secret, bool(admin))


if __name__ == "__main__":

@@ -119,6 +134,11 @@ if __name__ == "__main__":
        default=None,
        help="New password for user. Will prompt if omitted.",
    )
+    parser.add_argument(
+        "-a", "--admin",
+        action="store_true",
+        help="Register new user as an admin. Will prompt if omitted.",
+    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(

@@ -151,4 +171,4 @@ if __name__ == "__main__":
    else:
        secret = args.shared_secret

-    register_new_user(args.user, args.password, args.server_url, secret)
+    register_new_user(args.user, args.password, args.server_url, secret, args.admin)
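For reference, a standalone sketch of the MAC scheme visible above, with
placeholder values (Python 2 byte strings, like the script itself; on
Python 3 the inputs would need to be bytes)::

    import hashlib
    import hmac

    def registration_mac(shared_secret, user, password, admin=False):
        # NUL-separated HMAC-SHA1 over user, password and the admin flag,
        # mirroring request_registration() above.
        mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
        mac.update(user)
        mac.update("\x00")
        mac.update(password)
        mac.update("\x00")
        mac.update("admin" if admin else "notadmin")
        return mac.hexdigest()

    print(registration_mac("secret", "alice", "wonderland", admin=True))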


@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")

BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier"],
+    "events": ["processed", "outlier", "contains_url"],
    "rooms": ["is_public"],
    "event_edges": ["is_state"],
    "presence_list": ["accepted"],

@@ -92,8 +92,12 @@ class Store(object):
    _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
    _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
+    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
+    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
    _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
+        "_simple_select_one_onecol_txn"
+    ]

    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]

@@ -158,31 +162,40 @@ class Porter(object):
    def setup_table(self, table):
        if table in APPEND_ONLY_TABLES:
            # It's safe to just carry on inserting.
-            next_chunk = yield self.postgres_store._simple_select_one_onecol(
+            row = yield self.postgres_store._simple_select_one(
                table="port_from_sqlite3",
                keyvalues={"table_name": table},
-                retcol="rowid",
+                retcols=("forward_rowid", "backward_rowid"),
                allow_none=True,
            )

            total_to_port = None
-            if next_chunk is None:
+            if row is None:
                if table == "sent_transactions":
-                    next_chunk, already_ported, total_to_port = (
+                    forward_chunk, already_ported, total_to_port = (
                        yield self._setup_sent_transactions()
                    )
+                    backward_chunk = 0
                else:
                    yield self.postgres_store._simple_insert(
                        table="port_from_sqlite3",
-                        values={"table_name": table, "rowid": 1}
+                        values={
+                            "table_name": table,
+                            "forward_rowid": 1,
+                            "backward_rowid": 0,
+                        }
                    )

-                    next_chunk = 1
+                    forward_chunk = 1
+                    backward_chunk = 0
                    already_ported = 0
+            else:
+                forward_chunk = row["forward_rowid"]
+                backward_chunk = row["backward_rowid"]

            if total_to_port is None:
                already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, next_chunk
+                    table, forward_chunk, backward_chunk
                )
        else:
            def delete_all(txn):

@@ -196,46 +209,85 @@ class Porter(object):
            yield self.postgres_store._simple_insert(
                table="port_from_sqlite3",
-                values={"table_name": table, "rowid": 0}
+                values={
+                    "table_name": table,
+                    "forward_rowid": 1,
+                    "backward_rowid": 0,
+                }
            )

-            next_chunk = 1
+            forward_chunk = 1
+            backward_chunk = 0

            already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, next_chunk
+                table, forward_chunk, backward_chunk
            )

-        defer.returnValue((table, already_ported, total_to_port, next_chunk))
+        defer.returnValue(
+            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
+        )

    @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, next_chunk):
+    def handle_table(self, table, postgres_size, table_size, forward_chunk,
+                     backward_chunk):
        if not table_size:
            return

        self.progress.add_table(table, postgres_size, table_size)

        if table == "event_search":
-            yield self.handle_search_table(postgres_size, table_size, next_chunk)
+            yield self.handle_search_table(
+                postgres_size, table_size, forward_chunk, backward_chunk
+            )
            return

-        select = (
+        forward_select = (
            "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
            % (table,)
        )

+        backward_select = (
+            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
+            % (table,)
+        )

+        do_forward = [True]
+        do_backward = [True]

        while True:
            def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
-                rows = txn.fetchall()
+                forward_rows = []
+                backward_rows = []
+                if do_forward[0]:
+                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
+                    forward_rows = txn.fetchall()
+                    if not forward_rows:
+                        do_forward[0] = False

+                if do_backward[0]:
+                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
+                    backward_rows = txn.fetchall()
+                    if not backward_rows:
+                        do_backward[0] = False

+                if forward_rows or backward_rows:
                    headers = [column[0] for column in txn.description]
+                else:
+                    headers = None

-                return headers, rows
+                return headers, forward_rows, backward_rows

-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = yield self.sqlite_store.runInteraction(
+                "select", r
+            )

-            if rows:
-                next_chunk = rows[-1][0] + 1
+            if frows or brows:
+                if frows:
+                    forward_chunk = max(row[0] for row in frows) + 1
+                if brows:
+                    backward_chunk = min(row[0] for row in brows) - 1

+                rows = frows + brows
                self._convert_rows(table, headers, rows)

                def insert(txn):
@@ -247,7 +299,10 @@ class Porter(object):
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": table},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                    )

                yield self.postgres_store.execute(insert)

@@ -259,7 +314,8 @@ class Porter(object):
                return

    @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, next_chunk):
+    def handle_search_table(self, postgres_size, table_size, forward_chunk,
+                            backward_chunk):
        select = (
            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
            " FROM event_search as es"

@@ -270,7 +326,7 @@ class Porter(object):
        while True:
            def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
+                txn.execute(select, (forward_chunk, self.batch_size,))
                rows = txn.fetchall()
                headers = [column[0] for column in txn.description]

@@ -279,7 +335,7 @@ class Porter(object):
            headers, rows = yield self.sqlite_store.runInteraction("select", r)

            if rows:
-                next_chunk = rows[-1][0] + 1
+                forward_chunk = rows[-1][0] + 1

                # We have to treat event_search differently since it has a
                # different structure in the two different databases.

@@ -312,7 +368,10 @@ class Porter(object):
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": "event_search"},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                    )

                yield self.postgres_store.execute(insert)

@@ -324,7 +383,6 @@ class Porter(object):
            else:
                return

    def setup_db(self, db_config, database_engine):
        db_conn = database_engine.module.connect(
            **{
@@ -395,10 +453,32 @@ class Porter(object):
            txn.execute(
                "CREATE TABLE port_from_sqlite3 ("
                " table_name varchar(100) NOT NULL UNIQUE,"
-                " rowid bigint NOT NULL"
+                " forward_rowid bigint NOT NULL,"
+                " backward_rowid bigint NOT NULL"
                ")"
            )

+        # The old port script created a table with just a "rowid" column.
+        # We want people to be able to rerun this script from an old port
+        # so that they can pick up any missing events that were not
+        # ported across.
+        def alter_table(txn):
+            txn.execute(
+                "ALTER TABLE IF EXISTS port_from_sqlite3"
+                " RENAME rowid TO forward_rowid"
+            )
+            txn.execute(
+                "ALTER TABLE IF EXISTS port_from_sqlite3"
+                " ADD backward_rowid bigint NOT NULL DEFAULT 0"
+            )

+        try:
+            yield self.postgres_store.runInteraction(
+                "alter_table", alter_table
+            )
+        except Exception as e:
+            logger.info("Failed to create port table: %s", e)

        try:
            yield self.postgres_store.runInteraction(
                "create_port_table", create_port_table

@@ -458,7 +538,7 @@ class Porter(object):
    @defer.inlineCallbacks
    def _setup_sent_transactions(self):
        # Only save things from the last day
-        yesterday = int(time.time()*1000) - 86400000
+        yesterday = int(time.time() * 1000) - 86400000

        # And save the max transaction id from each destination
        select = (

@@ -514,7 +594,11 @@ class Porter(object):
        yield self.postgres_store._simple_insert(
            table="port_from_sqlite3",
-            values={"table_name": "sent_transactions", "rowid": next_chunk}
+            values={
+                "table_name": "sent_transactions",
+                "forward_rowid": next_chunk,
+                "backward_rowid": 0,
+            }
        )

        def get_sent_table_size(txn):

@@ -535,13 +619,18 @@ class Porter(object):
        defer.returnValue((next_chunk, inserted_rows, total_count))

    @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, next_chunk):
-        rows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+        frows = yield self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            next_chunk,
+            forward_chunk,
        )

-        defer.returnValue(rows[0][0])
+        brows = yield self.sqlite_store.execute_sql(
+            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
+            backward_chunk,
+        )

+        defer.returnValue(frows[0][0] + brows[0][0])

    @defer.inlineCallbacks
    def _get_already_ported_count(self, table):
@@ -552,10 +641,10 @@ class Porter(object):
        defer.returnValue(rows[0][0])

    @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, next_chunk):
+    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
        remaining, done = yield defer.gatherResults(
            [
-                self._get_remaining_count_to_port(table, next_chunk),
+                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
                self._get_already_ported_count(table),
            ],
            consumeErrors=True,

@@ -686,7 +775,7 @@ class CursesProgress(Progress):
            color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

            self.stdscr.addstr(
-                i+2, left_margin + max_len - len(table),
+                i + 2, left_margin + max_len - len(table),
                table,
                curses.A_BOLD | color,
            )

@@ -694,18 +783,18 @@ class CursesProgress(Progress):
            size = 20

            progress = "[%s%s]" % (
-                "#" * int(perc*size/100),
-                " " * (size - int(perc*size/100)),
+                "#" * int(perc * size / 100),
+                " " * (size - int(perc * size / 100)),
            )

            self.stdscr.addstr(
-                i+2, left_margin + max_len + middle_space,
+                i + 2, left_margin + max_len + middle_space,
                "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
            )

        if self.finished:
            self.stdscr.addstr(
-                rows-1, 0,
+                rows - 1, 0,
                "Press any key to exit...",
            )


@@ -16,7 +16,5 @@ ignore =
[flake8]
max-line-length = 90
-ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+ignore = W503

-[pep8]
-max-line-length = 90


@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""

-__version__ = "0.16.1-r1"
+__version__ = "0.17.0"


@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import logging

+import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import Requester, UserID, get_domain_from_id
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64

-import logging
-import pymacaroons
+import synapse.types
+from synapse.api.constants import EventTypes, Membership, JoinRules
+from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
+from synapse.types import UserID, get_domain_from_id
+from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logutils import log_function
+from synapse.util.metrics import Measure

logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Auth(object):
            "user_id = ",
        ])

-    def check(self, event, auth_events):
+    def check(self, event, auth_events, do_sig_check=True):
        """ Checks if this event is correctly authed.

        Args:

@@ -79,6 +79,13 @@ class Auth(object):
            if not hasattr(event, "room_id"):
                raise AuthError(500, "Event has no room_id: %s" % event)

+            sender_domain = get_domain_from_id(event.sender)

+            # Check the sender's domain has signed the event
+            if do_sig_check and not event.signatures.get(sender_domain):
+                raise AuthError(403, "Event not signed by sending server")

            if auth_events is None:
                # Oh, we don't know what the state of the room was, so we
                # are trusting that this is allowed (at least for now)

@@ -86,6 +93,12 @@ class Auth(object):
                return True

            if event.type == EventTypes.Create:
+                room_id_domain = get_domain_from_id(event.room_id)
+                if room_id_domain != sender_domain:
+                    raise AuthError(
+                        403,
+                        "Creation event's room_id domain does not match sender's"
+                    )
                # FIXME
                return True

@@ -108,6 +121,22 @@ class Auth(object):
            # FIXME: Temp hack
            if event.type == EventTypes.Aliases:
+                if not event.is_state():
+                    raise AuthError(
+                        403,
+                        "Alias event must be a state event",
+                    )
+                if not event.state_key:
+                    raise AuthError(
+                        403,
+                        "Alias event must have non-empty state_key"
+                    )
+                sender_domain = get_domain_from_id(event.sender)
+                if event.state_key != sender_domain:
+                    raise AuthError(
+                        403,
+                        "Alias event's state_key does not match sender's domain"
+                    )
                return True

            logger.debug(
@@ -347,6 +376,10 @@ class Auth(object):
        if Membership.INVITE == membership and "third_party_invite" in event.content:
            if not self._verify_third_party_invite(event, auth_events):
                raise AuthError(403, "You are not invited to this room.")
+            if target_banned:
+                raise AuthError(
+                    403, "%s is banned from the room" % (target_user_id,)
+                )
            return True

        if Membership.JOIN != membership:

@@ -537,9 +570,7 @@ class Auth(object):
        Args:
            request - An HTTP request with an access_token query parameter.
        Returns:
-            tuple of:
-                UserID (str)
-                Access token ID (str)
+            defer.Deferred: resolves to a ``synapse.types.Requester`` object
        Raises:
            AuthError if no user by that token exists or the token is invalid.
        """

@@ -548,9 +579,7 @@ class Auth(object):
            user_id = yield self._get_appservice_user_id(request.args)
            if user_id:
                request.authenticated_entity = user_id
-                defer.returnValue(
-                    Requester(UserID.from_string(user_id), "", False)
-                )
+                defer.returnValue(synapse.types.create_requester(user_id))

            access_token = request.args["access_token"][0]
            user_info = yield self.get_user_by_access_token(access_token, rights)

@@ -558,6 +587,10 @@ class Auth(object):
            token_id = user_info["token_id"]
            is_guest = user_info["is_guest"]

+            # device_id may not be present if get_user_by_access_token has been
+            # stubbed out.
+            device_id = user_info.get("device_id")

            ip_addr = self.hs.get_ip_from_request(request)
            user_agent = request.requestHeaders.getRawHeaders(
                "User-Agent",

@@ -569,7 +602,8 @@ class Auth(object):
                    user=user,
                    access_token=access_token,
                    ip=ip_addr,
-                    user_agent=user_agent
+                    user_agent=user_agent,
+                    device_id=device_id,
                )

            if is_guest and not allow_guest:

@@ -579,7 +613,8 @@ class Auth(object):
            request.authenticated_entity = user.to_string()

-            defer.returnValue(Requester(user, token_id, is_guest))
+            defer.returnValue(synapse.types.create_requester(
+                user, token_id, is_guest, device_id))
        except KeyError:
            raise AuthError(
                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

@@ -629,7 +664,10 @@ class Auth(object):
        except AuthError:
            # TODO(daniel): Remove this fallback when all existing access tokens
            # have been re-issued as macaroons.
+            if self.hs.config.expire_access_token:
+                raise
            ret = yield self._look_up_user_by_access_token(token)
            defer.returnValue(ret)

    @defer.inlineCallbacks
@@ -664,6 +702,7 @@ class Auth(object):
                "user": user,
                "is_guest": True,
                "token_id": None,
+                "device_id": None,
            }
        elif rights == "delete_pusher":
            # We don't store these tokens in the database

@@ -671,13 +710,20 @@ class Auth(object):
                "user": user,
                "is_guest": False,
                "token_id": None,
+                "device_id": None,
            }
        else:
-            # This codepath exists so that we can actually return a
-            # token ID, because we use token IDs in place of device
-            # identifiers throughout the codebase.
-            # TODO(daniel): Remove this fallback when device IDs are
-            # properly implemented.
+            # This codepath exists for several reasons:
+            #   * so that we can actually return a token ID, which is used
+            #     in some parts of the schema (where we probably ought to
+            #     use device IDs instead)
+            #   * the only way we currently have to invalidate an
+            #     access_token is by removing it from the database, so we
+            #     have to check here that it is still in the db
+            #   * some attributes (notably device_id) aren't stored in the
+            #     macaroon. They probably should be.
+            # TODO: build the dictionary from the macaroon once the
+            #       above are fixed
            ret = yield self._look_up_user_by_access_token(macaroon_str)
            if ret["user"] != user:
                logger.error(

@@ -751,10 +797,14 @@ class Auth(object):
                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
                errcode=Codes.UNKNOWN_TOKEN
            )
+        # we use ret.get() below because *lots* of unit tests stub out
+        # get_user_by_access_token in a way where it only returns a couple of
+        # the fields.
        user_info = {
            "user": UserID.from_string(ret.get("name")),
            "token_id": ret.get("token_id", None),
            "is_guest": False,
+            "device_id": ret.get("device_id"),
        }

        defer.returnValue(user_info)


@@ -42,8 +42,10 @@ class Codes(object):
    TOO_LARGE = "M_TOO_LARGE"
    EXCLUSIVE = "M_EXCLUSIVE"
    THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
-    THREEPID_IN_USE = "THREEPID_IN_USE"
+    THREEPID_IN_USE = "M_THREEPID_IN_USE"
+    THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
    INVALID_USERNAME = "M_INVALID_USERNAME"
+    SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"


class CodeMessageException(RuntimeError):


@@ -191,6 +191,17 @@ class Filter(object):
    def __init__(self, filter_json):
        self.filter_json = filter_json

+        self.types = self.filter_json.get("types", None)
+        self.not_types = self.filter_json.get("not_types", [])

+        self.rooms = self.filter_json.get("rooms", None)
+        self.not_rooms = self.filter_json.get("not_rooms", [])

+        self.senders = self.filter_json.get("senders", None)
+        self.not_senders = self.filter_json.get("not_senders", [])

+        self.contains_url = self.filter_json.get("contains_url", None)

    def check(self, event):
        """Checks whether the filter matches the given event.

@@ -209,9 +220,10 @@ class Filter(object):
            event.get("room_id", None),
            sender,
            event.get("type", None),
+            "url" in event.get("content", {})
        )

-    def check_fields(self, room_id, sender, event_type):
+    def check_fields(self, room_id, sender, event_type, contains_url):
        """Checks whether the filter matches the given event fields.

        Returns:

@@ -225,15 +237,20 @@ class Filter(object):
        for name, match_func in literal_keys.items():
            not_name = "not_%s" % (name,)
-            disallowed_values = self.filter_json.get(not_name, [])
+            disallowed_values = getattr(self, not_name)
            if any(map(match_func, disallowed_values)):
                return False

-            allowed_values = self.filter_json.get(name, None)
+            allowed_values = getattr(self, name)
            if allowed_values is not None:
                if not any(map(match_func, allowed_values)):
                    return False

+        contains_url_filter = self.filter_json.get("contains_url")
+        if contains_url_filter is not None:
+            if contains_url_filter != contains_url:
+                return False

        return True

    def filter_rooms(self, room_ids):
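To illustrate the new field (a sketch, not part of the commit, assuming the
module path shown here): a filter with ``contains_url`` set matches only
events whose content carries a ``url`` key::

    from synapse.api.filtering import Filter

    f = Filter({"types": ["m.room.message"], "contains_url": True})

    event = {
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "sender": "@alice:example.com",
        "content": {"url": "mxc://example.com/abc123"},
    }

    print(f.check(event))  # True; False if the content had no "url" key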


@@ -16,13 +16,11 @@
import sys
sys.dont_write_bytecode = True

-from synapse.python_dependencies import (
-    check_requirements, MissingRequirementError
-)  # NOQA
+from synapse import python_dependencies  # noqa: E402

try:
-    check_requirements()
-except MissingRequirementError as e:
+    python_dependencies.check_requirements()
+except python_dependencies.MissingRequirementError as e:
    message = "\n".join([
        "Missing Requirement: %s" % (e.message,),
        "To install run:",


@ -0,0 +1,206 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize

import sys
import logging
import gc

logger = logging.getLogger("synapse.app.federation_reader")


class FederationReaderSlavedStore(
    SlavedEventStore,
    SlavedKeyStore,
    RoomStore,
    DirectoryStore,
    TransactionStore,
    BaseSlavedStore,
):
    pass


class FederationReaderServer(HomeServer):
    def get_db_conn(self, run_new_connection=True):
        # Any param beginning with cp_ is a parameter for adbapi, and should
        # not be passed to the database engine.
        db_params = {
            k: v for k, v in self.db_config.get("args", {}).items()
            if not k.startswith("cp_")
        }
        db_conn = self.database_engine.module.connect(**db_params)

        if run_new_connection:
            self.database_engine.on_new_connection(db_conn)
        return db_conn

    def setup(self):
        logger.info("Setting up.")
        self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
        logger.info("Finished setting up.")

    def _listen_http(self, listener_config):
        port = listener_config["port"]
        bind_address = listener_config.get("bind_address", "")
        site_tag = listener_config.get("tag", port)
        resources = {}
        for res in listener_config["resources"]:
            for name in res["names"]:
                if name == "metrics":
                    resources[METRICS_PREFIX] = MetricsResource(self)
                elif name == "federation":
                    resources.update({
                        FEDERATION_PREFIX: TransportLayerServer(self),
                    })

        root_resource = create_resource_tree(resources, Resource())
        reactor.listenTCP(
            port,
            SynapseSite(
                "synapse.access.http.%s" % (site_tag,),
                site_tag,
                listener_config,
                root_resource,
            ),
            interface=bind_address
        )
        logger.info("Synapse federation reader now listening on port %d", port)

    def start_listening(self, listeners):
        for listener in listeners:
            if listener["type"] == "http":
                self._listen_http(listener)
            elif listener["type"] == "manhole":
                reactor.listenTCP(
                    listener["port"],
                    manhole(
                        username="matrix",
                        password="rabbithole",
                        globals={"hs": self},
                    ),
                    interface=listener.get("bind_address", '127.0.0.1')
                )
            else:
                logger.warn("Unrecognized listener type: %s", listener["type"])

    @defer.inlineCallbacks
    def replicate(self):
        http_client = self.get_simple_http_client()
        store = self.get_datastore()
        replication_url = self.config.worker_replication_url

        while True:
            try:
                args = store.stream_positions()
                args["timeout"] = 30000
                result = yield http_client.get_json(replication_url, args=args)
                yield store.process_replication(result)
            except:
                logger.exception("Error replicating from %r", replication_url)
                yield sleep(5)


def start(config_options):
    try:
        config = HomeServerConfig.load_config(
            "Synapse federation reader", config_options
        )
    except ConfigError as e:
        sys.stderr.write("\n" + e.message + "\n")
        sys.exit(1)

    assert config.worker_app == "synapse.app.federation_reader"

    setup_logging(config.worker_log_config, config.worker_log_file)

    database_engine = create_engine(config.database_config)

    tls_server_context_factory = context_factory.ServerContextFactory(config)

    ss = FederationReaderServer(
        config.server_name,
        db_config=config.database_config,
        tls_server_context_factory=tls_server_context_factory,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    ss.setup()
    ss.get_handlers()
    ss.start_listening(config.worker_listeners)

    def run():
        with LoggingContext("run"):
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:
                gc.set_threshold(*config.gc_thresholds)
            reactor.run()

    def start():
        ss.get_datastore().start_profiling()
        ss.replicate()

    reactor.callWhenRunning(start)

    if config.worker_daemonize:
        daemon = Daemonize(
            app="synapse-federation-reader",
            pid=config.worker_pid_file,
            action=run,
            auto_close_fds=False,
            verbose=True,
            logger=logger,
        )
        daemon.start()
    else:
        run()


if __name__ == '__main__':
    with LoggingContext("main"):
        start(sys.argv[1:])
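
The worker's replicate() loop above is a plain long-poll: it re-sends its current stream positions to the master's replication URL, waits up to 30 seconds for new data, processes it, and backs off briefly on failure. A minimal synchronous sketch of the same pattern, using requests instead of the Twisted HTTP client (the endpoint, argument names, and timings are taken from the code above; everything else is illustrative):

import time
import requests  # assumption: a stand-in for synapse's Twisted client

def replicate_forever(replication_url, get_stream_positions, process_result):
    """Long-poll a replication endpoint, mirroring replicate() above."""
    while True:
        try:
            args = dict(get_stream_positions())
            args["timeout"] = 30000  # ms; the server holds the request open
            result = requests.get(replication_url, params=args).json()
            process_result(result)
        except Exception:
            time.sleep(5)  # brief back-off on error, as the worker does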

@@ -51,6 +51,7 @@ from synapse.api.urls import (
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.util.logcontext import LoggingContext
+from synapse.metrics import register_memory_metrics
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
 from synapse.federation.transport.server import TransportLayerServer
@@ -147,7 +148,7 @@ class SynapseHomeServer(HomeServer):
                 MEDIA_PREFIX: media_repo,
                 LEGACY_MEDIA_PREFIX: media_repo,
                 CONTENT_REPO_PREFIX: ContentRepoResource(
-                    self, self.config.uploads_path, self.auth, self.content_addr
+                    self, self.config.uploads_path
                 ),
             })
@@ -284,7 +285,7 @@ def setup(config_options):
     # check any extra requirements we have now we have a config
     check_requirements(config)

-    version_string = get_version_string("Synapse", synapse)
+    version_string = "Synapse/" + get_version_string(synapse)

     logger.info("Server hostname: %s", config.server_name)
     logger.info("Server version: %s", version_string)
@@ -301,7 +302,6 @@ def setup(config_options):
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
         config=config,
-        content_addr=config.content_addr,
         version_string=version_string,
         database_engine=database_engine,
     )
@@ -336,6 +336,8 @@ def setup(config_options):
     hs.get_datastore().start_doing_background_updates()
     hs.get_replication_layer().start_get_pdu_cache()

+    register_memory_metrics(hs)
+
     reactor.callWhenRunning(start)

     return hs

@@ -273,7 +273,7 @@ def start(config_options):
         config.server_name,
         db_config=config.database_config,
         config=config,
-        version_string=get_version_string("Synapse", synapse),
+        version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
     )

@@ -424,7 +424,7 @@ def start(config_options):
         config.server_name,
         db_config=config.database_config,
         config=config,
-        version_string=get_version_string("Synapse", synapse),
+        version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
         application_service_handler=SynchrotronApplicationService(),
     )

@@ -13,40 +13,88 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import Config
+from ._base import Config, ConfigError
+
+MISSING_LDAP3 = (
+    "Missing ldap3 library. This is required for LDAP Authentication."
+)
+
+
+class LDAPMode(object):
+    SIMPLE = "simple",
+    SEARCH = "search",
+
+    LIST = (SIMPLE, SEARCH)


 class LDAPConfig(Config):
     def read_config(self, config):
-        ldap_config = config.get("ldap_config", None)
-        if ldap_config:
-            self.ldap_enabled = ldap_config.get("enabled", False)
-            self.ldap_server = ldap_config["server"]
-            self.ldap_port = ldap_config["port"]
-            self.ldap_tls = ldap_config.get("tls", False)
-            self.ldap_search_base = ldap_config["search_base"]
-            self.ldap_search_property = ldap_config["search_property"]
-            self.ldap_email_property = ldap_config["email_property"]
-            self.ldap_full_name_property = ldap_config["full_name_property"]
-        else:
-            self.ldap_enabled = False
-            self.ldap_server = None
-            self.ldap_port = None
-            self.ldap_tls = False
-            self.ldap_search_base = None
-            self.ldap_search_property = None
-            self.ldap_email_property = None
-            self.ldap_full_name_property = None
+        ldap_config = config.get("ldap_config", {})
+
+        self.ldap_enabled = ldap_config.get("enabled", False)
+
+        if self.ldap_enabled:
+            # verify dependencies are available
+            try:
+                import ldap3
+                ldap3  # to stop unused lint
+            except ImportError:
+                raise ConfigError(MISSING_LDAP3)
+
+            self.ldap_mode = LDAPMode.SIMPLE
+
+            # verify config sanity
+            self.require_keys(ldap_config, [
+                "uri",
+                "base",
+                "attributes",
+            ])
+
+            self.ldap_uri = ldap_config["uri"]
+            self.ldap_start_tls = ldap_config.get("start_tls", False)
+            self.ldap_base = ldap_config["base"]
+            self.ldap_attributes = ldap_config["attributes"]
+
+            if "bind_dn" in ldap_config:
+                self.ldap_mode = LDAPMode.SEARCH
+                self.require_keys(ldap_config, [
+                    "bind_dn",
+                    "bind_password",
+                ])
+
+                self.ldap_bind_dn = ldap_config["bind_dn"]
+                self.ldap_bind_password = ldap_config["bind_password"]
+                self.ldap_filter = ldap_config.get("filter", None)
+
+            # verify attribute lookup
+            self.require_keys(ldap_config['attributes'], [
+                "uid",
+                "name",
+                "mail",
+            ])
+
+    def require_keys(self, config, required):
+        missing = [key for key in required if key not in config]
+        if missing:
+            raise ConfigError(
+                "LDAP enabled but missing required config values: {}".format(
+                    ", ".join(missing)
+                )
+            )

     def default_config(self, **kwargs):
         return """\
         # ldap_config:
         #   enabled: true
-        #   server: "ldap://localhost"
-        #   port: 389
-        #   tls: false
-        #   search_base: "ou=Users,dc=example,dc=com"
-        #   search_property: "cn"
-        #   email_property: "email"
-        #   full_name_property: "givenName"
+        #   uri: "ldap://ldap.example.com:389"
+        #   start_tls: true
+        #   base: "ou=users,dc=example,dc=com"
+        #   attributes:
+        #      uid: "cn"
+        #      mail: "email"
+        #      name: "givenName"
+        #   #bind_dn:
+        #   #bind_password:
+        #   #filter: "(objectClass=posixAccount)"
         """

@@ -23,10 +23,14 @@ class PasswordConfig(Config):
     def read_config(self, config):
         password_config = config.get("password_config", {})
         self.password_enabled = password_config.get("enabled", True)
+        self.password_pepper = password_config.get("pepper", "")

     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
         # Enable password for login.
         password_config:
            enabled: true
+           # Uncomment and change to a secret random string for extra security.
+           # DO NOT CHANGE THIS AFTER INITIAL SETUP!
+           #pepper: ""
         """

@@ -107,26 +107,6 @@ class ServerConfig(Config):
             ]
         })

-        # Attempt to guess the content_addr for the v0 content repostitory
-        content_addr = config.get("content_addr")
-        if not content_addr:
-            for listener in self.listeners:
-                if listener["type"] == "http" and not listener.get("tls", False):
-                    unsecure_port = listener["port"]
-                    break
-            else:
-                raise RuntimeError("Could not determine 'content_addr'")
-
-            host = self.server_name
-            if ':' not in host:
-                host = "%s:%d" % (host, unsecure_port)
-            else:
-                host = host.split(':')[0]
-                host = "%s:%d" % (host, unsecure_port)
-            content_addr = "http://%s" % (host,)
-
-        self.content_addr = content_addr
-
     def default_config(self, server_name, **kwargs):
         if ":" in server_name:
             bind_port = int(server_name.split(":")[1])
@@ -169,7 +149,6 @@ class ServerConfig(Config):
         # room directory.
         # secondary_directory_servers:
         #     - matrix.org
-        #     - vector.im

         # List of ports that Synapse should listen on, their purpose and their
         # configuration.

@@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
     def __init__(self):
         self.remote_key = defer.Deferred()
         self.host = None
+        self._peer = None

     def connectionMade(self):
-        self.host = self.transport.getHost()
-        logger.debug("Connected to %s", self.host)
+        self._peer = self.transport.getPeer()
+        logger.debug("Connected to %s", self._peer)
         self.sendCommand(b"GET", self.path)
         if self.host:
             self.sendHeader(b"Host", self.host)
@@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
         self.timer.cancel()

     def on_timeout(self):
-        logger.debug("Timeout waiting for response from %s", self.host)
+        logger.debug(
+            "Timeout waiting for response from %s: %s",
+            self.host, self._peer,
+        )
         self.errback(IOError("Timeout waiting for response"))
         self.transport.abortConnection()

@@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
     def protocol(self):
         protocol = SynapseKeyClientProtocol()
         protocol.path = self.path
+        protocol.host = self.host
         return protocol

@@ -44,7 +44,21 @@ import logging

 logger = logging.getLogger(__name__)

-KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+VerifyKeyRequest = namedtuple("VerifyRequest", (
+    "server_name", "key_ids", "json_object", "deferred"
+))
+"""
+A request for a verify key to verify a JSON object.
+
+Attributes:
+    server_name(str): The name of the server to verify against.
+    key_ids(set(str)): The set of key_ids to that could be used to verify the
+        JSON object
+    json_object(dict): The JSON object to verify.
+    deferred(twisted.internet.defer.Deferred):
+        A deferred (server_name, key_id, verify_key) tuple that resolves when
+        a verify key has been fetched
+"""


 class Keyring(object):
@@ -74,39 +88,32 @@ class Keyring(object):
             list of deferreds indicating success or failure to verify each
             json object's signature for the given server_name.
         """
-        group_id_to_json = {}
-        group_id_to_group = {}
-        group_ids = []
-
-        next_group_id = 0
-        deferreds = {}
+        verify_requests = []

         for server_name, json_object in server_and_json:
             logger.debug("Verifying for %s", server_name)
-            group_id = next_group_id
-            next_group_id += 1
-            group_ids.append(group_id)

             key_ids = signature_ids(json_object, server_name)
             if not key_ids:
-                deferreds[group_id] = defer.fail(SynapseError(
+                deferred = defer.fail(SynapseError(
                     400,
                     "Not signed with a supported algorithm",
                     Codes.UNAUTHORIZED,
                 ))
             else:
-                deferreds[group_id] = defer.Deferred()
+                deferred = defer.Deferred()

-            group = KeyGroup(server_name, group_id, key_ids)
+            verify_request = VerifyKeyRequest(
+                server_name, key_ids, json_object, deferred
+            )

-            group_id_to_group[group_id] = group
-            group_id_to_json[group_id] = json_object
+            verify_requests.append(verify_request)

         @defer.inlineCallbacks
-        def handle_key_deferred(group, deferred):
-            server_name = group.server_name
+        def handle_key_deferred(verify_request):
+            server_name = verify_request.server_name
             try:
-                _, _, key_id, verify_key = yield deferred
+                _, key_id, verify_key = yield verify_request.deferred
             except IOError as e:
                 logger.warn(
                     "Got IOError when downloading keys for %s: %s %s",
@@ -128,7 +135,7 @@ class Keyring(object):
                     Codes.UNAUTHORIZED,
                 )

-            json_object = group_id_to_json[group.group_id]
+            json_object = verify_request.json_object

             try:
                 verify_signed_json(json_object, server_name, verify_key)
@@ -157,36 +164,34 @@ class Keyring(object):
         # Actually start fetching keys.
         wait_on_deferred.addBoth(
-            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+            lambda _: self.get_server_verify_keys(verify_requests)
         )

         # When we've finished fetching all the keys for a given server_name,
         # resolve the deferred passed to `wait_for_previous_lookups` so that
         # any lookups waiting will proceed.
-        server_to_gids = {}
+        server_to_request_ids = {}

-        def remove_deferreds(res, server_name, group_id):
-            server_to_gids[server_name].discard(group_id)
-            if not server_to_gids[server_name]:
+        def remove_deferreds(res, server_name, verify_request):
+            request_id = id(verify_request)
+            server_to_request_ids[server_name].discard(request_id)
+            if not server_to_request_ids[server_name]:
                 d = server_to_deferred.pop(server_name, None)
                 if d:
                     d.callback(None)
             return res

-        for g_id, deferred in deferreds.items():
-            server_name = group_id_to_group[g_id].server_name
-            server_to_gids.setdefault(server_name, set()).add(g_id)
-            deferred.addBoth(remove_deferreds, server_name, g_id)
+        for verify_request in verify_requests:
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids.setdefault(server_name, set()).add(request_id)
+            deferred.addBoth(remove_deferreds, server_name, verify_request)

         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified
         return [
-            preserve_context_over_fn(
-                handle_key_deferred,
-                group_id_to_group[g_id],
-                deferreds[g_id],
-            )
-            for g_id in group_ids
+            preserve_context_over_fn(handle_key_deferred, verify_request)
+            for verify_request in verify_requests
         ]
@@ -220,7 +225,7 @@ class Keyring(object):
             d.addBoth(rm, server_name)

-    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+    def get_server_verify_keys(self, verify_requests):
         """Takes a dict of KeyGroups and tries to find at least one key for
         each group.
         """
@@ -237,62 +242,64 @@ class Keyring(object):
             merged_results = {}

             missing_keys = {}
-            for group in group_id_to_group.values():
-                missing_keys.setdefault(group.server_name, set()).update(
-                    group.key_ids
+            for verify_request in verify_requests:
+                missing_keys.setdefault(verify_request.server_name, set()).update(
+                    verify_request.key_ids
                 )

             for fn in key_fetch_fns:
                 results = yield fn(missing_keys.items())
                 merged_results.update(results)

-                # We now need to figure out which groups we have keys for
-                # and which we don't
-                missing_groups = {}
-                for group in group_id_to_group.values():
-                    for key_id in group.key_ids:
-                        if key_id in merged_results[group.server_name]:
+                # We now need to figure out which verify requests we have keys
+                # for and which we don't
+                missing_keys = {}
+                requests_missing_keys = []
+                for verify_request in verify_requests:
+                    server_name = verify_request.server_name
+                    result_keys = merged_results[server_name]
+
+                    if verify_request.deferred.called:
+                        # We've already called this deferred, which probably
+                        # means that we've already found a key for it.
+                        continue
+
+                    for key_id in verify_request.key_ids:
+                        if key_id in result_keys:
                             with PreserveLoggingContext():
-                                group_id_to_deferred[group.group_id].callback((
-                                    group.group_id,
-                                    group.server_name,
+                                verify_request.deferred.callback((
+                                    server_name,
                                     key_id,
-                                    merged_results[group.server_name][key_id],
+                                    result_keys[key_id],
                                 ))
                             break
                     else:
-                        missing_groups.setdefault(
-                            group.server_name, []
-                        ).append(group)
-
-                if not missing_groups:
+                        # The else block is only reached if the loop above
+                        # doesn't break.
+                        missing_keys.setdefault(server_name, set()).update(
+                            verify_request.key_ids
+                        )
+                        requests_missing_keys.append(verify_request)
+
+                if not missing_keys:
                     break

-                missing_keys = {
-                    server_name: set(
-                        key_id for group in groups for key_id in group.key_ids
-                    )
-                    for server_name, groups in missing_groups.items()
-                }
-
-            for group in missing_groups.values():
-                group_id_to_deferred[group.group_id].errback(SynapseError(
+            for verify_request in requests_missing_keys.values():
+                verify_request.deferred.errback(SynapseError(
                     401,
                     "No key for %s with id %s" % (
-                        group.server_name, group.key_ids,
+                        verify_request.server_name, verify_request.key_ids,
                     ),
                     Codes.UNAUTHORIZED,
                 ))

         def on_err(err):
-            for deferred in group_id_to_deferred.values():
-                if not deferred.called:
-                    deferred.errback(err)
+            for verify_request in verify_requests:
+                if not verify_request.deferred.called:
+                    verify_request.deferred.errback(err)

         do_iterations().addErrback(on_err)

-        return group_id_to_deferred
-
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
         res = yield defer.gatherResults(
@@ -447,7 +454,7 @@ class Keyring(object):
             )

             processed_response = yield self.process_v2_response(
-                perspective_name, response
+                perspective_name, response, only_from_server=False
             )

             for server_name, response_keys in processed_response.items():
@@ -527,7 +534,7 @@ class Keyring(object):
     @defer.inlineCallbacks
     def process_v2_response(self, from_server, response_json,
-                            requested_ids=[]):
+                            requested_ids=[], only_from_server=True):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -551,6 +558,13 @@ class Keyring(object):
         results = {}

         server_name = response_json["server_name"]
+        if only_from_server:
+            if server_name != from_server:
+                raise ValueError(
+                    "Expected a response for server %r not %r" % (
+                        from_server, server_name
+                    )
+                )
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
                 raise ValueError(
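
The refactor above replaces integer key-group IDs with a VerifyKeyRequest namedtuple, so each pending verification carries its own deferred and can be resolved or errbacked independently as keys arrive. A small self-contained sketch of constructing one, with illustrative values:

from collections import namedtuple
from twisted.internet import defer

VerifyKeyRequest = namedtuple("VerifyRequest", (
    "server_name", "key_ids", "json_object", "deferred"
))

request = VerifyKeyRequest(
    server_name="example.org",        # server whose signature we check
    key_ids={"ed25519:a_key_id"},     # any of these keys may verify it
    json_object={"signatures": {}},   # the signed JSON under test
    deferred=defer.Deferred(),        # fires with (server_name, key_id, key)
)
# verification continues once a fetcher calls request.deferred.callback(...)
request.deferred.addCallback(lambda result: result)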

@@ -236,9 +236,9 @@ class FederationClient(FederationBase):
         # TODO: Rate limit the number of times we try and get the same event.

         if self._get_pdu_cache:
-            e = self._get_pdu_cache.get(event_id)
-            if e:
-                defer.returnValue(e)
+            ev = self._get_pdu_cache.get(event_id)
+            if ev:
+                defer.returnValue(ev)

         pdu = None
         for destination in destinations:
@@ -269,7 +269,7 @@ class FederationClient(FederationBase):
                 break
-            except SynapseError:
+            except SynapseError as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -314,6 +314,42 @@ class FederationClient(FederationBase):
             Deferred: Results in a list of PDUs.
         """
+        try:
+            # First we try and ask for just the IDs, as thats far quicker if
+            # we have most of the state and auth_chain already.
+            # However, this may 404 if the other side has an old synapse.
+            result = yield self.transport_layer.get_room_state_ids(
+                destination, room_id, event_id=event_id,
+            )
+
+            state_event_ids = result["pdu_ids"]
+            auth_event_ids = result.get("auth_chain_ids", [])
+
+            fetched_events, failed_to_fetch = yield self.get_events(
+                [destination], room_id, set(state_event_ids + auth_event_ids)
+            )
+
+            if failed_to_fetch:
+                logger.warn("Failed to get %r", failed_to_fetch)
+
+            event_map = {
+                ev.event_id: ev for ev in fetched_events
+            }
+
+            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+            auth_chain = [
+                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+            ]
+
+            auth_chain.sort(key=lambda e: e.depth)
+
+            defer.returnValue((pdus, auth_chain))
+        except HttpResponseException as e:
+            if e.code == 400 or e.code == 404:
+                logger.info("Failed to use get_room_state_ids API, falling back")
+            else:
+                raise e
+
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
@@ -327,18 +363,93 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]

+        seen_events = yield self.store.get_events([
+            ev.event_id for ev in itertools.chain(pdus, auth_chain)
+        ])
+
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination, pdus, outlier=True
+            destination,
+            [p for p in pdus if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_pdus.extend(
+            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
         )

         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True
+            destination,
+            [p for p in auth_chain if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_auth.extend(
+            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
         )

         signed_auth.sort(key=lambda e: e.depth)

         defer.returnValue((signed_pdus, signed_auth))

+    @defer.inlineCallbacks
+    def get_events(self, destinations, room_id, event_ids, return_local=True):
+        """Fetch events from some remote destinations, checking if we already
+        have them.
+
+        Args:
+            destinations (list)
+            room_id (str)
+            event_ids (list)
+            return_local (bool): Whether to include events we already have in
+                the DB in the returned list of events
+
+        Returns:
+            Deferred: A deferred resolving to a 2-tuple where the first is a list of
+            events and the second is a list of event ids that we failed to fetch.
+        """
+        if return_local:
+            seen_events = yield self.store.get_events(event_ids)
+            signed_events = seen_events.values()
+        else:
+            seen_events = yield self.store.have_events(event_ids)
+            signed_events = []
+
+        failed_to_fetch = set()
+
+        missing_events = set(event_ids)
+        for k in seen_events:
+            missing_events.discard(k)
+
+        if not missing_events:
+            defer.returnValue((signed_events, failed_to_fetch))
+
+        def random_server_list():
+            srvs = list(destinations)
+            random.shuffle(srvs)
+            return srvs
+
+        batch_size = 20
+        missing_events = list(missing_events)
+        for i in xrange(0, len(missing_events), batch_size):
+            batch = set(missing_events[i:i + batch_size])
+
+            deferreds = [
+                self.get_pdu(
+                    destinations=random_server_list(),
+                    event_id=e_id,
+                )
+                for e_id in batch
+            ]
+
+            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            for success, result in res:
+                if success:
+                    signed_events.append(result)
+                    batch.discard(result.event_id)
+
+            # We removed all events we successfully fetched from `batch`
+            failed_to_fetch.update(batch)
+
+        defer.returnValue((signed_events, failed_to_fetch))
+
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
@@ -414,14 +525,19 @@ class FederationClient(FederationBase):
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.warn(
+                        "Failed to make_%s via %s: %s",
+                        membership, destination, e.message
+                    )
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
-                raise

         raise RuntimeError("Failed to send to any server.")
@@ -493,8 +609,14 @@ class FederationClient(FederationBase):
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.exception(
+                        "Failed to send_join via %s: %s",
+                        destination, e.message
+                    )
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",

@@ -21,10 +21,11 @@ from .units import Transaction, Edu

 from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics

-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError

 from synapse.crypto.event_signing import compute_event_signature
@@ -48,7 +49,14 @@ class FederationServer(FederationBase):
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)

+        self.auth = hs.get_auth()
+
         self._room_pdu_linearizer = Linearizer()
+        self._server_linearizer = Linearizer()
+
+        # We cache responses to state queries, as they take a while and often
+        # come in waves.
+        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)

     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
@@ -89,11 +97,14 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
-        pdus = yield self.handler.on_backfill_request(
-            origin, room_id, versions, limit
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            pdus = yield self.handler.on_backfill_request(
+                origin, room_id, versions, limit
+            )
+
+            res = self._transaction_from_pdus(pdus).get_dict()

-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+        defer.returnValue((200, res))

     @defer.inlineCallbacks
     @log_function
@@ -184,9 +195,50 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if event_id:
-            pdus = yield self.handler.get_state_for_pdu(
-                origin, room_id, event_id,
-            )
+        if not event_id:
+            raise NotImplementedError("Specify an event")
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        result = self._state_resp_cache.get((room_id, event_id))
+        if not result:
+            with (yield self._server_linearizer.queue((origin, room_id))):
+                resp = yield self._state_resp_cache.set(
+                    (room_id, event_id),
+                    self._on_context_state_request_compute(room_id, event_id)
+                )
+        else:
+            resp = yield result
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_state_ids_request(self, origin, room_id, event_id):
+        if not event_id:
+            raise NotImplementedError("Specify an event")
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
+        defer.returnValue((200, {
+            "pdu_ids": [pdu.event_id for pdu in pdus],
+            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+        }))
+
+    @defer.inlineCallbacks
+    def _on_context_state_request_compute(self, room_id, event_id):
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
         auth_chain = yield self.store.get_auth_chain(
             [pdu.event_id for pdu in pdus]
         )
@@ -203,13 +255,11 @@ class FederationServer(FederationBase):
                     self.hs.config.signing_key[0]
                 )
             )
-        else:
-            raise NotImplementedError("Specify an event")

-        defer.returnValue((200, {
+        defer.returnValue({
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        }))
+        })

     @defer.inlineCallbacks
     @log_function
@@ -283,14 +333,16 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
-        time_now = self._clock.time_msec()
-        auth_pdus = yield self.handler.on_event_auth(event_id)
-        defer.returnValue((200, {
-            "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-        }))
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            time_now = self._clock.time_msec()
+            auth_pdus = yield self.handler.on_event_auth(event_id)
+            res = {
+                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+            }
+        defer.returnValue((200, res))

     @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, event_id):
+    def on_query_auth_request(self, origin, content, room_id, event_id):
         """
         Content is a dict with keys::
             auth_chain (list): A list of events that give the auth chain.
@@ -309,6 +361,7 @@ class FederationServer(FederationBase):
         Returns:
             Deferred: Results in `dict` with the same format as `content`
         """
+        with (yield self._server_linearizer.queue((origin, room_id))):
             auth_chain = [
                 self.event_from_pdu_json(e)
                 for e in content["auth_chain"]
@@ -340,27 +393,9 @@ class FederationServer(FederationBase):
                 (200, send_content)
             )

-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)

     @defer.inlineCallbacks
     @log_function
@@ -386,6 +421,7 @@ class FederationServer(FederationBase):
     @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
+        with (yield self._server_linearizer.queue((origin, room_id))):
             logger.info(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d, min_depth: %d",
@@ -396,7 +432,9 @@ class FederationServer(FederationBase):
             )

             if len(missing_events) < 5:
-                logger.info("Returning %d events: %r", len(missing_events), missing_events)
+                logger.info(
+                    "Returning %d events: %r", len(missing_events), missing_events
+                )
             else:
                 logger.info("Returning %d events", len(missing_events))
@@ -567,7 +605,7 @@ class FederationServer(FederationBase):
                 origin, pdu.room_id, pdu.event_id,
             )
         except:
-            logger.warn("Failed to get state for event: %s", pdu.event_id)
+            logger.exception("Failed to get state for event: %s", pdu.event_id)

         yield self.handler.on_receive_pdu(
             origin,
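
The ResponseCache introduced above collapses concurrent /state requests for the same (room_id, event_id): the first caller computes and publishes an in-flight result, and later callers wait on it instead of recomputing. A synchronous sketch of the same cache-or-compute shape (the real class is deferred-based and expires entries after timeout_ms=30000; none of that is shown here):

class InFlightCache(object):
    """Very rough stand-in for synapse's ResponseCache (no expiry shown)."""
    def __init__(self):
        self._results = {}

    def get(self, key):
        return self._results.get(key)

    def set(self, key, result):
        self._results[key] = result
        return result

def state_request(cache, room_id, event_id, compute):
    result = cache.get((room_id, event_id))
    if result is None:
        # first caller: compute once and publish for everyone else
        result = cache.set((room_id, event_id), compute(room_id, event_id))
    return result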

@@ -54,6 +54,28 @@ class TransportLayerClient(object):
             destination, path=path, args={"event_id": event_id},
         )

+    @log_function
+    def get_room_state_ids(self, destination, room_id, event_id):
+        """ Requests all state for a given room from the given server at the
+        given event. Returns the state's event_id's
+
+        Args:
+            destination (str): The host name of the remote home server we want
+                to get the state from.
+            context (str): The name of the context we want the state of
+            event_id (str): The event we want the context at.
+
+        Returns:
+            Deferred: Results in a dict received from the remote homeserver.
+        """
+        logger.debug("get_room_state_ids dest=%s, room=%s",
+                     destination, room_id)
+
+        path = PREFIX + "/state_ids/%s/" % room_id
+        return self.client.get_json(
+            destination, path=path, args={"event_id": event_id},
+        )
+
     @log_function
     def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.

@@ -18,13 +18,14 @@ from twisted.internet import defer

 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request, parse_string
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string

 import functools
 import logging
-import simplejson as json
 import re
+import synapse

 logger = logging.getLogger(__name__)
@@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
         )

+class AuthenticationError(SynapseError):
+    """There was a problem authenticating the request"""
+    pass
+
+
+class NoAuthenticationError(AuthenticationError):
+    """The request had no authentication information"""
+    pass
+
+
 class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
@@ -67,7 +78,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def authenticate_request(self, request):
+    def authenticate_request(self, request, content):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -75,17 +86,10 @@ class Authenticator(object):
             "signatures": {},
         }

-        content = None
-        origin = None
-
-        if request.method in ["PUT", "POST"]:
-            # TODO: Handle other method types? other content types?
-            try:
-                content_bytes = request.content.read()
-                content = json.loads(content_bytes)
-                json_request["content"] = content
-            except:
-                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+        if content is not None:
+            json_request["content"] = content
+
+        origin = None

         def parse_auth_header(header_str):
             try:
@@ -103,14 +107,14 @@ class Authenticator(object):
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
             except:
-                raise SynapseError(
+                raise AuthenticationError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )

         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

         if not auth_headers:
-            raise SynapseError(
+            raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -121,7 +125,7 @@ class Authenticator(object):
             json_request["signatures"].setdefault(origin, {})[key] = sig

         if not json_request["signatures"]:
-            raise SynapseError(
+            raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -130,10 +134,12 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin

-        defer.returnValue((origin, content))
+        defer.returnValue(origin)


 class BaseFederationServlet(object):
+    REQUIRE_AUTH = True
+
     def __init__(self, handler, authenticator, ratelimiter, server_name,
                  room_list_handler):
         self.handler = handler
@@ -141,29 +147,46 @@ class BaseFederationServlet(object):
         self.ratelimiter = ratelimiter
         self.room_list_handler = room_list_handler

-    def _wrap(self, code):
+    def _wrap(self, func):
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter

         @defer.inlineCallbacks
-        @functools.wraps(code)
-        def new_code(request, *args, **kwargs):
+        @functools.wraps(func)
+        def new_func(request, *args, **kwargs):
+            content = None
+            if request.method in ["PUT", "POST"]:
+                # TODO: Handle other method types? other content types?
+                content = parse_json_object_from_request(request)
+
             try:
-                (origin, content) = yield authenticator.authenticate_request(request)
-                with ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield code(
-                        origin, content, request.args, *args, **kwargs
-                    )
+                origin = yield authenticator.authenticate_request(request, content)
+            except NoAuthenticationError:
+                origin = None
+                if self.REQUIRE_AUTH:
+                    logger.exception("authenticate_request failed")
+                    raise
             except:
                 logger.exception("authenticate_request failed")
                 raise
+
+            if origin:
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield func(
+                        origin, content, request.args, *args, **kwargs
+                    )
+            else:
+                response = yield func(
+                    origin, content, request.args, *args, **kwargs
+                )
+
             defer.returnValue(response)

         # Extra logic that functools.wraps() doesn't finish
-        new_code.__self__ = code.__self__
+        new_func.__self__ = func.__self__

-        return new_code
+        return new_func

     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -271,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
         )

+class FederationStateIdsServlet(BaseFederationServlet):
+    PATH = "/state_ids/(?P<room_id>[^/]*)/"
+
+    def on_GET(self, origin, content, query, room_id):
+        return self.handler.on_state_ids_request(
+            origin,
+            room_id,
+            query.get("event_id", [None])[0],
+        )
+
+
 class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
@@ -367,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):

 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"

-    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        response = yield self.handler.on_query_client_keys(origin, content)
-        defer.returnValue((200, response))
+        return self.handler.on_query_client_keys(origin, content)


 class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -388,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, context, event_id):
         new_content = yield self.handler.on_query_auth_request(
-            origin, content, event_id
+            origin, content, context, event_id
         )

         defer.returnValue((200, new_content))
@@ -420,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):

 class On3pidBindServlet(BaseFederationServlet):
     PATH = "/3pid/onbind"

+    REQUIRE_AUTH = False
+
     @defer.inlineCallbacks
-    def on_POST(self, request):
-        content = parse_json_object_from_request(request)
+    def on_POST(self, origin, content, query):
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -444,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
             raise last_exception
         defer.returnValue((200, {}))

-    # Avoid doing remote HS authorization checks which are done by default by
-    # BaseFederationServlet.
-    def _wrap(self, code):
-        return code
-

 class OpenIdUserInfo(BaseFederationServlet):
     """
@@ -469,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
     PATH = "/openid/userinfo"

+    REQUIRE_AUTH = False
+
     @defer.inlineCallbacks
-    def on_GET(self, request):
-        token = parse_string(request, "access_token")
+    def on_GET(self, origin, content, query):
+        token = query.get("access_token", [None])[0]
         if token is None:
             defer.returnValue((401, {
                 "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -488,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
         defer.returnValue((200, {"sub": user_id}))

-    # Avoid doing remote HS authorization checks which are done by default by
-    # BaseFederationServlet.
-    def _wrap(self, code):
-        return code
-

 class PublicRoomList(BaseFederationServlet):
     """
@@ -533,11 +558,26 @@ class PublicRoomList(BaseFederationServlet):
         defer.returnValue((200, data))


+class FederationVersionServlet(BaseFederationServlet):
+    PATH = "/version"
+
+    REQUIRE_AUTH = False
+
+    def on_GET(self, origin, content, query):
+        return defer.succeed((200, {
+            "server": {
+                "name": "Synapse",
+                "version": get_version_string(synapse)
+            },
+        }))
+
+
 SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
+    FederationStateIdsServlet,
     FederationBackfillServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,
@@ -555,6 +595,7 @@ SERVLET_CLASSES = (
     On3pidBindServlet,
     OpenIdUserInfo,
     PublicRoomList,
+    FederationVersionServlet,
 )
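
The REQUIRE_AUTH flag lets individual servlets opt out of federation auth declaratively instead of overriding _wrap, and the new /version endpoint uses it. Assuming the default federation port, it can therefore be exercised without an Authorization header, e.g.:

import requests  # assumption: any HTTP client will do

# illustrative homeserver name; no signing or auth headers are needed
resp = requests.get("https://example.org:8448/_matrix/federation/v1/version")
# resp.json() -> {"server": {"name": "Synapse", "version": "..."}}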

@@ -31,10 +31,21 @@ from .search import SearchHandler


 class Handlers(object):

-    """ A collection of all the event handlers.
+    """ Deprecated. A collection of handlers.

-    There's no need to lazily create these; we'll just make them all eagerly
-    at construction time.
+    At some point most of the classes whose name ended "Handler" were
+    accessed through this class.
+
+    However this makes it painful to unit test the handlers and to run cut
+    down versions of synapse that only use specific handlers because using a
+    single handler required creating all of the handlers. So some of the
+    handlers have been lifted out of the Handlers object and are now accessed
+    directly through the homeserver object itself.
+
+    Any new handlers should follow the new pattern of being accessed through
+    the homeserver object and should not be added to the Handlers object.
+
+    The remaining handlers should be moved out of the handlers object.
     """

     def __init__(self, hs):
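
In concrete terms, the deprecation means new code takes handlers straight off the HomeServer rather than through the eager Handlers collection. A sketch contrasting the two styles (the device handler getter appears elsewhere in this release):

def get_handlers_example(hs):
    # deprecated: reach through the eagerly-built Handlers collection
    registration_handler = hs.get_handlers().registration_handler

    # preferred: dedicated getters on the HomeServer object itself
    device_handler = hs.get_device_handler()

    return registration_handler, device_handler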

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+
 from twisted.internet import defer

-from synapse.api.errors import LimitExceededError
+import synapse.types
 from synapse.api.constants import Membership, EventTypes
-from synapse.types import UserID, Requester
-
-import logging
+from synapse.api.errors import LimitExceededError
+from synapse.types import UserID

 logger = logging.getLogger(__name__)
@@ -31,11 +31,15 @@ class BaseHandler(object):
     Common base class for the event handlers.

     Attributes:
-        store (synapse.storage.events.StateStore):
+        store (synapse.storage.DataStore):
         state_handler (synapse.state.StateHandler):
     """

     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -120,7 +124,8 @@ class BaseHandler(object):
             # and having homeservers have their own users leave keeps more
             # of that decision-making and control local to the guest-having
             # homeserver.
-            requester = Requester(target_user, "", True)
+            requester = synapse.types.create_requester(
+                target_user, is_guest=True)
             handler = self.hs.get_handlers().room_member_handler
             yield handler.update_membership(
                 requester,

@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -28,6 +29,12 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
except ImportError:
ldap3 = None
pass
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = { self.checkers = {
LoginType.PASSWORD: self._check_password_auth, LoginType.PASSWORD: self._check_password_auth,
@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401 self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled self.ldap_enabled = hs.config.ldap_enabled
self.ldap_server = hs.config.ldap_server if self.ldap_enabled:
self.ldap_port = hs.config.ldap_port if not ldap3:
self.ldap_tls = hs.config.ldap_tls raise RuntimeError(
self.ldap_search_base = hs.config.ldap_search_base 'Missing ldap3 library. This is required for LDAP Authentication.'
self.ldap_search_property = hs.config.ldap_search_property )
self.ldap_email_property = hs.config.ldap_email_property self.ldap_mode = hs.config.ldap_mode
self.ldap_full_name_property = hs.config.ldap_full_name_property self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
if self.ldap_enabled is True: self.ldap_base = hs.config.ldap_base
import ldap self.ldap_filter = hs.config.ldap_filter
logger.info("Import ldap version: %s", ldap.__version__) self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
if not (yield self._check_password(user_id, password)): return self._check_password(user_id, password)
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -270,7 +280,16 @@ class AuthHandler(BaseHandler):
data = pde.response data = pde.response
resp_body = simplejson.loads(data) resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']: if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
defer.returnValue(True) defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks def validate_password_login(self, user_id, password):
def login_with_password(self, user_id, password):
""" """
Authenticates the user with their username and password. Authenticates the user with their username and password.
Used only by the v1 login API. Used only by the v1 login API.
Args: Args:
user_id (str): User ID user_id (str): complete @user:id
password (str): Password password (str): Password
Returns: Returns:
A tuple of: defer.Deferred: (str) canonical user id
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem accessing the database
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
return self._check_password(user_id, password)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id): def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS) machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args: Args:
user_id (str): User ID user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns: Returns:
A tuple of: A tuple of:
The user's ID.
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session. The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
logger.info("Logging in user %s", user_id) # the device *should* have been registered before we got here; however,
access_token = yield self.issue_access_token(user_id) # it's possible we raced against a DELETE operation. The thing we
refresh_token = yield self.issue_refresh_token(user_id) # really don't want is active access_tokens without a record of the
defer.returnValue((user_id, access_token, refresh_token)) # device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
def check_user_exists(self, user_id):
    """
    Checks to see if a user with the given id exists. Will check case
    insensitively, but return None if there are multiple inexact matches.

    Args:
        (str) user_id: complete @user:id

    Returns:
        defer.Deferred: (str) canonical_user_id, or None if zero or
        multiple matches
    """
    try:
        res = yield self._find_user_id_and_pwd_hash(user_id)
        defer.returnValue(res[0])
    except LoginError:
        defer.returnValue(None)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
    """Authenticate a user against the LDAP and local databases.

    user_id is checked case insensitively against the local database, but
    will throw if there are multiple inexact matches.

    Args:
        user_id (str): complete @user:id

    Returns:
        (str) the canonical_user_id
    Raises:
        LoginError if the password was incorrect
    """
    valid_ldap = yield self._check_ldap_password(user_id, password)
    if valid_ldap:
        defer.returnValue(user_id)

    result = yield self._check_local_password(user_id, password)
    defer.returnValue(result)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
    """Authenticate a user against the local password database.

    user_id is checked case insensitively, but will throw if there are
    multiple inexact matches.

    Args:
        user_id (str): complete @user:id

    Returns:
        (str) the canonical_user_id
    Raises:
        LoginError if the password was incorrect
    """
    user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
    result = self.validate_hash(password, password_hash)
    if not result:
        logger.warn("Failed password login for user %s", user_id)
        raise LoginError(403, "", errcode=Codes.FORBIDDEN)
    defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
    """ Attempt to authenticate a user against an LDAP Server
    and register an account if none exists.

    Returns:
        True if authentication against LDAP was successful
    """

    if not ldap3 or not self.ldap_enabled:
        defer.returnValue(False)

    if self.ldap_mode not in LDAPMode.LIST:
        raise RuntimeError(
            'Invalid ldap mode specified: {mode}'.format(
                mode=self.ldap_mode
            )
        )

    try:
        server = ldap3.Server(self.ldap_uri)
        logger.debug(
            "Attempting ldap connection with %s",
            self.ldap_uri
        )

        localpart = UserID.from_string(user_id).localpart
        if self.ldap_mode == LDAPMode.SIMPLE:
            # bind with the local user's ldap credentials
            bind_dn = "{prop}={value},{base}".format(
                prop=self.ldap_attributes['uid'],
                value=localpart,
                base=self.ldap_base
            )
            conn = ldap3.Connection(server, bind_dn, password)
            logger.debug(
                "Established ldap connection in simple mode: %s",
                conn
            )

            if self.ldap_start_tls:
                conn.start_tls()
                logger.debug(
                    "Upgraded ldap connection in simple mode through StartTLS: %s",
                    conn
                )

            conn.bind()

        elif self.ldap_mode == LDAPMode.SEARCH:
            # connect with preconfigured credentials and search for local user
            conn = ldap3.Connection(
                server,
                self.ldap_bind_dn,
                self.ldap_bind_password
            )
            logger.debug(
                "Established ldap connection in search mode: %s",
                conn
            )

            if self.ldap_start_tls:
                conn.start_tls()
                logger.debug(
                    "Upgraded ldap connection in search mode through StartTLS: %s",
                    conn
                )

            conn.bind()

            # find matching dn
            query = "({prop}={value})".format(
                prop=self.ldap_attributes['uid'],
                value=localpart
            )
            if self.ldap_filter:
                query = "(&{query}{filter})".format(
                    query=query,
                    filter=self.ldap_filter
                )
            logger.debug("ldap search filter: %s", query)
            result = conn.search(self.ldap_base, query)

            if result and len(conn.response) == 1:
                # found exactly one result
                user_dn = conn.response[0]['dn']
                logger.debug('ldap search found dn: %s', user_dn)

                # unbind and reconnect, rebind with found dn
                conn.unbind()
                conn = ldap3.Connection(
                    server,
                    user_dn,
                    password,
                    auto_bind=True
                )
            else:
                # found 0 or > 1 results, abort!
                logger.warn(
                    "ldap search returned unexpected (%d!=1) amount of results",
                    len(conn.response)
                )
                defer.returnValue(False)

        logger.info(
            "User authenticated against ldap server: %s",
            conn
        )

        # check for existing account, if none exists, create one
        if not (yield self.check_user_exists(user_id)):
            # query user metadata for account creation
            query = "({prop}={value})".format(
                prop=self.ldap_attributes['uid'],
                value=localpart
            )

            if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
                query = "(&{filter}{user_filter})".format(
                    filter=query,
                    user_filter=self.ldap_filter
                )
            logger.debug("ldap registration filter: %s", query)

            result = conn.search(
                search_base=self.ldap_base,
                search_filter=query,
                attributes=[
                    self.ldap_attributes['name'],
                    self.ldap_attributes['mail']
                ]
            )

            if len(conn.response) == 1:
                attrs = conn.response[0]['attributes']
                mail = attrs[self.ldap_attributes['mail']][0]
                name = attrs[self.ldap_attributes['name']][0]

                # create account
                registration_handler = self.hs.get_handlers().registration_handler
                user_id, access_token = (
                    yield registration_handler.register(localpart=localpart)
                )

                # TODO: bind email, set displayname with data from ldap directory
                logger.info(
                    "ldap registration successful: %s: %s (%s, %s)",
                    user_id,
                    localpart,
                    name,
                    mail
                )
            else:
                logger.warn(
                    "ldap registration failed: unexpected (%d!=1) amount of results",
                    len(conn.response)
                )
                defer.returnValue(False)

        defer.returnValue(True)
    except ldap3.core.exceptions.LDAPException as e:
        logger.warn("Error during ldap authentication: %s", e)
        defer.returnValue(False)
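For reference, a standalone sketch of the ldap3 simple-bind pattern used above (the URI, base DN and attribute values here are invented):

    import ldap3

    def ldap_simple_bind(uri, base, uid_attr, localpart, password):
        # bind as the user directly; a successful bind proves the password
        server = ldap3.Server(uri)
        bind_dn = "%s=%s,%s" % (uid_attr, localpart, base)
        conn = ldap3.Connection(server, bind_dn, password)
        try:
            return conn.bind()
        finally:
            conn.unbind()

    # e.g. ldap_simple_bind("ldap://localhost:389", "ou=users,dc=example,dc=org",
    #                       "uid", "alice", "wonderland")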
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
    access_token = self.generate_access_token(user_id)
    yield self.store.add_access_token_to_user(user_id, access_token,
                                              device_id)
    defer.returnValue(access_token)

@defer.inlineCallbacks
def issue_refresh_token(self, user_id, device_id=None):
    refresh_token = self.generate_refresh_token(user_id)
    yield self.store.add_refresh_token_to_user(user_id, refresh_token,
                                               device_id)
    defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None,
                          duration_in_ms=(60 * 60 * 1000)):
    extra_caveats = extra_caveats or []
    macaroon = self._generate_base_macaroon(user_id)
    macaroon.add_first_party_caveat("type = access")
    now = self.hs.get_clock().time_msec()
    expiry = now + duration_in_ms
    macaroon.add_first_party_caveat("time < %d" % (expiry,))
    for caveat in extra_caveats:
        macaroon.add_first_party_caveat(caveat)
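A rough sketch of what such a token looks like with pymacaroons (the location, identifier, key and caveat values here are invented):

    import pymacaroons

    macaroon = pymacaroons.Macaroon(
        location="example.com",
        identifier="key_id",
        key="secret_signing_key",
    )
    macaroon.add_first_party_caveat("user_id = @alice:example.com")
    macaroon.add_first_party_caveat("type = access")
    macaroon.add_first_party_caveat("time < 1470000000000")
    token = macaroon.serialize()  # opaque string handed to the client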
@ -613,7 +797,8 @@ class AuthHandler(BaseHandler):
    Returns:
        Hashed password (str).
    """
    return bcrypt.hashpw(password + self.hs.config.password_pepper,
                         bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash):
    """Validates that self.hash(password) == stored_hash.
@ -626,6 +811,7 @@ class AuthHandler(BaseHandler):
        Whether self.hash(password) == stored_hash (bool).
    """
    if stored_hash:
        return bcrypt.hashpw(password + self.hs.config.password_pepper,
                             stored_hash.encode('utf-8')) == stored_hash
    else:
        return False
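The pepper scheme round-trips like this (a self-contained sketch; the pepper value is invented and must stay constant for the lifetime of the stored hashes):

    import bcrypt

    pepper = "mySecretPepper"

    def hash_password(password):
        return bcrypt.hashpw(password + pepper, bcrypt.gensalt(12))

    def check_password(password, stored_hash):
        # bcrypt embeds the salt in the hash, so rehashing with the stored
        # value reproduces it iff the peppered password matches
        return bcrypt.hashpw(password + pepper, stored_hash) == stored_hash

    h = hash_password("s3cret")
    assert check_password("s3cret", h)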

synapse/handlers/device.py (new file)
View file

@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string_with_symbols(16)
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})
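For reference, the per-device dicts returned by get_devices_by_user end up shaped roughly like this once the client-ip data is merged in (all values invented):

    {
        "user_id": "@alice:example.com",
        "device_id": "QBUAZIFURK",
        "display_name": "android phone",
        "last_seen_ts": 1470000000000,
        "last_seen_ip": "192.168.0.1",
    }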

View file

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import logging
from twisted.internet import defer
from synapse.api import errors
import synapse.types
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id)
queries_by_domain[user.domain][user_id] = device_ids
# do the queries
# TODO: do these in parallel
results = {}
for destination, destination_query in queries_by_domain.items():
if destination == self.server_name:
res = yield self.query_local_devices(destination_query)
else:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None]): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})
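Sketch of a round trip through query_devices (the user and device ids are invented):

    query_body = {
        "device_keys": {
            "@alice:example.com": [],          # empty list: all of alice's devices
            "@bob:remote.org": ["ABCDEFGH"],   # one specific device
        }
    }
    # local users are answered from the database, remote users are proxied
    # over federation, and the results are merged per user:
    # (200, {"device_keys": {"@alice:example.com": {...}, "@bob:remote.org": {...}}})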

View file

@ -124,7 +124,7 @@ class FederationHandler(BaseHandler):
try:
    event_stream_id, max_stream_id = yield self._persist_auth_tree(
        origin, auth_chain, state, event
    )
except AuthError as e:
    raise FederationError(
@ -335,18 +335,35 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state

required_auth = set(
    a_id
    for event in events + state_events.values() + auth_events.values()
    for a_id, _ in event.auth_events
)
auth_events.update({
    e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
})
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()

# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
    logger.info("Missing auth for backfill: %r", missing_auth)
    ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
    auth_events.update(ret_events)
    required_auth.update(
        a_id for event in ret_events.values() for a_id, _ in event.auth_events
    )
    missing_auth = required_auth - set(auth_events)

    if missing_auth - failed_to_fetch:
        logger.info(
            "Fetching missing auth for backfill: %r",
            missing_auth - failed_to_fetch
        )

        results = yield defer.gatherResults(
            [
                self.replication_layer.get_pdu(
@ -355,11 +372,21 @@ class FederationHandler(BaseHandler):
                    outlier=True,
                    timeout=10000,
                )
                for event_id in missing_auth - failed_to_fetch
            ],
            consumeErrors=True
        ).addErrback(unwrapFirstError)
        auth_events.update({a.event_id: a for a in results})
        required_auth.update(
            a_id for event in results for a_id, _ in event.auth_events
        )
        missing_auth = required_auth - set(auth_events)

    failed_to_fetch = missing_auth - set(auth_events)

seen_events = yield self.store.have_events(
    set(auth_events.keys()) | set(state_events.keys())
)

ev_infos = []
for a in auth_events.values():
@ -372,6 +399,7 @@ class FederationHandler(BaseHandler):
        (auth_events[a_id].type, auth_events[a_id].state_key):
            auth_events[a_id]
        for a_id, _ in a.auth_events
        if a_id in auth_events
    }
})
@ -383,6 +411,7 @@ class FederationHandler(BaseHandler):
        (auth_events[a_id].type, auth_events[a_id].state_key):
            auth_events[a_id]
        for a_id, _ in event_map[e_id].auth_events
        if a_id in auth_events
    }
})
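The missing-auth fetching above is a fixed-point iteration: each batch of fetched auth events can name further auth events, so it repeats until nothing new turns up. In miniature (a pure-Python sketch with invented names, no I/O):

    def fetch_transitively(required, lookup):
        """lookup(event_id) -> list of ids the event depends on, or None."""
        have = {}
        missing = set(required)
        failed = set()
        while missing - failed:
            for e_id in list(missing - failed):
                deps = lookup(e_id)
                if deps is None:
                    failed.add(e_id)      # give up on this one permanently
                else:
                    have[e_id] = deps
                    missing.update(deps)  # may grow the frontier
            missing -= set(have)
        return have, failed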
@ -637,7 +666,7 @@ class FederationHandler(BaseHandler):
    pass

event_stream_id, max_stream_id = yield self._persist_auth_tree(
    origin, auth_chain, state, event
)

with PreserveLoggingContext():
@ -688,7 +717,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e) logger.warn("Failed to create join %r because %s", event, e)
raise e raise e
self.auth.check(event, auth_events=context.current_state) # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event) defer.returnValue(event)
@ -918,7 +949,9 @@ class FederationHandler(BaseHandler):
)

try:
    # The remote hasn't signed it yet, obviously. We'll do the full checks
    # when we get the event back in `on_send_leave_request`
    self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e:
    logger.warn("Failed to create new leave %r because %s", event, e)
    raise e
@ -987,14 +1020,9 @@ class FederationHandler(BaseHandler):
    defer.returnValue(None)

@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
    yield run_on_reactor()

    state_groups = yield self.store.get_state_groups(
        room_id, [event_id]
    )
@ -1114,6 +1142,7 @@ class FederationHandler(BaseHandler):
    backfilled=backfilled,
)

if not backfilled:
    # this intentionally does not yield: we don't care about the result
    # and don't need to wait for it.
    preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
@ -1150,11 +1179,19 @@ class FederationHandler(BaseHandler):
)

@defer.inlineCallbacks
def _persist_auth_tree(self, origin, auth_events, state, event):
    """Checks the auth chain is valid (and passes auth checks) for the
    state and event. Then persists the auth chain and state atomically.
    Persists the event separately.

    Will attempt to fetch missing auth events.

    Args:
        origin (str): Where the events came from
        auth_events (list)
        state (list)
        event (Event)

    Returns:
        2-tuple of (event_stream_id, max_stream_id) from the persist_event
        call for `event`
@ -1167,7 +1204,7 @@ class FederationHandler(BaseHandler):
    event_map = {
        e.event_id: e
        for e in itertools.chain(auth_events, state, [event])
    }

    create_event = None
@ -1176,10 +1213,29 @@ class FederationHandler(BaseHandler):
            create_event = e
            break

    missing_auth_events = set()
    for e in itertools.chain(auth_events, state, [event]):
        for e_id, _ in e.auth_events:
            if e_id not in event_map:
                missing_auth_events.add(e_id)

    for e_id in missing_auth_events:
        m_ev = yield self.replication_layer.get_pdu(
            [origin],
            e_id,
            outlier=True,
            timeout=10000,
        )
        if m_ev and m_ev.event_id == e_id:
            event_map[e_id] = m_ev
        else:
            logger.info("Failed to find auth event %r", e_id)

    for e in itertools.chain(auth_events, state, [event]):
        auth_for_e = {
            (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
            for e_id, _ in e.auth_events
            if e_id in event_map
        }
        if create_event:
            auth_for_e[(EventTypes.Create, "")] = create_event
@ -1413,7 +1469,7 @@ class FederationHandler(BaseHandler):
    local_view = dict(auth_events)
    remote_view = dict(auth_events)
    remote_view.update({
        (d.type, d.state_key): d for d in different_events if d
    })

    new_state, prev_state = self.state_handler.resolve_events(

View file

@ -21,7 +21,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError, Codes

import json
import logging
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
        hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
    )

def _should_trust_id_server(self, id_server):
    if id_server not in self.trusted_id_servers:
        if self.trust_any_id_server_just_for_testing_do_not_use:
            logger.warn(
                "Trusting untrustworthy ID server %r even though it isn't"
                " in the trusted id list for testing because"
                " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
                " is set in the config",
                id_server,
            )
        else:
            return False
    return True
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
    yield run_on_reactor()
@ -59,18 +73,11 @@ class IdentityHandler(BaseHandler):
    else:
        raise SynapseError(400, "No client_secret in creds")

    if not self._should_trust_id_server(id_server):
        logger.warn(
            '%s is not a trusted ID server: rejecting 3pid ' +
            'credentials', id_server
        )
        defer.returnValue(None)

    data = {}
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
    yield run_on_reactor()

    if not self._should_trust_id_server(id_server):
        raise SynapseError(
            400, "Untrusted ID server '%s'" % id_server,
            Codes.SERVER_NOT_TRUSTED
        )

    params = {
        'email': email,
        'client_secret': client_secret,
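In miniature, the trust check factored out above behaves like this (the config values are invented):

    trusted_id_servers = {"matrix.org", "vector.im"}
    trust_any = False  # the 'just_for_testing_do_not_use' escape hatch

    def should_trust(id_server):
        # unknown servers only get through in the insecure test mode
        return id_server in trusted_id_servers or trust_any

    assert should_trust("matrix.org")
    assert not should_trust("evil.example.com")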

View file

@ -26,7 +26,7 @@ from synapse.types import (
    UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
    self.validator = EventValidator()
    self.snapshot_cache = SnapshotCache()
    self.pagination_lock = ReadWriteLock()

@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
    event = yield self.store.get_event(event_id)

    if event.room_id != room_id:
        raise SynapseError(400, "Event is for wrong room.")

    depth = event.depth

    with (yield self.pagination_lock.write(room_id)):
        yield self.store.delete_old_state(room_id, depth)

@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
                 as_client_event=True, event_filter=None):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
        pagin_config (synapse.api.streams.PaginationConfig): The pagination
            config rules to apply, if any.
        as_client_event (bool): True to get events in client-server format.
        event_filter (Filter): Filter to apply to results or None
    Returns:
        dict: Pagination API results
    """
    user_id = requester.user.to_string()

    if pagin_config.from_token:
        room_token = pagin_config.from_token.room_key
@ -85,6 +99,7 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)):
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id room_id, user_id
) )
@ -95,7 +110,7 @@ class MessageHandler(BaseHandler):
        if room_token.topological:
            max_topo = room_token.topological
        else:
            max_topo = yield self.store.get_max_topological_token(
                room_id, room_token.stream
            )
@ -114,8 +129,13 @@ class MessageHandler(BaseHandler):
                room_id, max_topo
            )

        events, next_key = yield self.store.paginate_room_events(
            room_id=room_id,
            from_key=source_config.from_key,
            to_key=source_config.to_key,
            direction=source_config.direction,
            limit=source_config.limit,
            event_filter=event_filter,
        )

        next_token = pagin_config.from_token.copy_and_replace(
@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, self.store,
user_id, user_id,
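With the event_filter plumbing above, a client can restrict /messages to, say, events carrying a URL. A sketch of building the query string (the filter fields are the ones named in the changelog; the rest is illustrative):

    import json
    import urllib

    event_filter = {"contains_url": True, "types": ["m.room.message"]}
    qs = urllib.urlencode({"filter": json.dumps(event_filter)})
    # GET /_matrix/client/r0/rooms/{roomId}/messages?<qs> then returns only
    # m.room.message events whose content includes a "url" key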

View file

@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from twisted.internet import defer

import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID

from ._base import BaseHandler

logger = logging.getLogger(__name__)
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
try:
    # Assume the user isn't a guest because we don't let guests set
    # profile or avatar data.
    # XXX why are we recreating `requester` here for each room?
    # what was wrong with the `requester` we were passed?
    requester = synapse.types.create_requester(user)
    yield handler.update_membership(
        requester,
        user,

View file

@ -14,18 +14,19 @@
# limitations under the License.

"""Contains functions for registering clients."""

import logging
import urllib

from twisted.internet import defer

import synapse.types
from synapse.api.errors import (
    AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from synapse.http.client import CaptchaServerHttpClient
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from ._base import BaseHandler

logger = logging.getLogger(__name__)
@ -52,6 +53,13 @@ class RegistrationHandler(BaseHandler):
        Codes.INVALID_USERNAME
    )

if localpart[0] == '_':
    raise SynapseError(
        400,
        "User ID may not begin with _",
        Codes.INVALID_USERNAME
    )

user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@ -90,7 +98,8 @@ class RegistrationHandler(BaseHandler):
    password=None,
    generate_token=True,
    guest_access_token=None,
    make_guest=False,
    admin=False,
):
"""Registers a new client on the server. """Registers a new client on the server.
@ -100,6 +109,11 @@ class RegistrationHandler(BaseHandler):
        password (str) : The password to assign to this user so they can
            login again. This can be None which means they cannot login again
            via a password (e.g. the user is an application service user).
        generate_token (bool): Whether a new access token should be
            generated. Having this be True should be considered deprecated,
            since it offers no means of associating a device_id with the
            access_token. Instead you should call auth_handler.issue_access_token
            after registration.
    Returns:
        A tuple of (user_id, access_token).
    Raises:
@ -141,6 +155,7 @@ class RegistrationHandler(BaseHandler):
            # If the user was a guest then they already have a profile
            None if was_guest else user.localpart
        ),
        admin=admin,
    )
else:
    # autogen a sequential user ID
@ -194,15 +209,13 @@ class RegistrationHandler(BaseHandler):
    user_id, allowed_appservice=service
)

yield self.store.register(
    user_id=user_id,
    password_hash="",
    appservice_id=service_id,
    create_profile_with_localpart=user.localpart,
)
defer.returnValue(user_id)
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@ -358,7 +371,8 @@ class RegistrationHandler(BaseHandler):
    defer.returnValue(data)

@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_in_ms,
                       password_hash=None):
    """Creates a new user if the user does not exist,
    else revokes all previous access tokens and generates a new one.
@ -387,14 +401,14 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()

token = self.auth_handler().generate_access_token(
    user_id, None, duration_in_ms)

if need_register:
    yield self.store.register(
        user_id=user_id,
        token=token,
        password_hash=password_hash,
        create_profile_with_localpart=user.localpart,
    )
else:
@ -404,8 +418,9 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
    logger.info("setting user display name: %s -> %s", user_id, displayname)
    profile_handler = self.hs.get_handlers().profile_handler
    requester = synapse.types.create_requester(user)
    yield profile_handler.set_displayname(
        user, requester, displayname
    )

defer.returnValue((user_id, token))

View file

@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler):
    def __init__(self, hs):
        super(RoomListHandler, self).__init__(hs)
        self.response_cache = ResponseCache(hs)
        self.remote_list_request_cache = ResponseCache(hs)
        self.remote_list_cache = {}
        self.fetch_looping_call = hs.get_clock().looping_call(
            self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL

View file

@ -14,24 +14,22 @@
# limitations under the License.

import logging

from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64

import synapse.types
from synapse.api.constants import (
    EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room

from ._base import BaseHandler

logger = logging.getLogger(__name__)
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
    )
    assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
    requester = synapse.types.create_requester(target_user)

message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)

View file

@ -138,7 +138,7 @@ class SyncHandler(object):
    self.presence_handler = hs.get_presence_handler()
    self.event_sources = hs.get_event_sources()
    self.clock = hs.get_clock()
    self.response_cache = ResponseCache(hs)

def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                           full_state=False):

View file

@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback):
    for path_pattern in path_patterns:
        logger.debug("Registering for %s %s", method, path_pattern.pattern)
        self.path_regexs.setdefault(method, []).append(
            self._PathEntry(path_pattern, callback)
        )

View file

@ -27,7 +27,8 @@ import gc
from twisted.internet import reactor

from .metric import (
    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
    MemoryUsageMetric,
)
@ -66,6 +67,21 @@ class Metrics(object):
        return self._register(CacheMetric, *args, **kwargs)

def register_memory_metrics(hs):
    try:
        import psutil
        process = psutil.Process()
        process.memory_info().rss
    except (ImportError, AttributeError):
        logger.warn(
            "psutil is not installed or incorrect version."
            " Disabling memory metrics."
        )
        return
    metric = MemoryUsageMetric(hs, psutil)
    all_metrics.append(metric)

def get_metrics_for(pkg_name):
    """ Returns a Metrics instance for conveniently creating metrics
    namespaced with the given name prefix. """

View file

@ -153,3 +153,43 @@ class CacheMetric(object):
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
] ]
class MemoryUsageMetric(object):
"""Keeps track of the current memory usage, using psutil.
The class will keep the current min/max/sum/counts of rss over the last
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
"""
UPDATE_HZ = 2 # number of times to get memory per second
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
def __init__(self, hs, psutil):
clock = hs.get_clock()
self.memory_snapshots = []
self.process = psutil.Process()
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
def _update_curr_values(self):
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
self.memory_snapshots.append(self.process.memory_info().rss)
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
def render(self):
if not self.memory_snapshots:
return []
max_rss = max(self.memory_snapshots)
min_rss = min(self.memory_snapshots)
sum_rss = sum(self.memory_snapshots)
len_rss = len(self.memory_snapshots)
return [
"process_psutil_rss:max %d" % max_rss,
"process_psutil_rss:min %d" % min_rss,
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]
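For reference, the rendered output of this metric looks something like the following (the numbers are invented; a 30s window sampled at 2 Hz holds at most 60 snapshots, and total / count recovers the windowed mean):

    process_psutil_rss:max 210764800
    process_psutil_rss:min 205319168
    process_psutil_rss:total 12441485312
    process_psutil_rss:count 60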

View file

@ -14,6 +14,7 @@
# limitations under the License.

from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled

import logging
@ -92,7 +93,11 @@ class EmailPusher(object):
def on_stop(self):
    if self.timed_call:
        try:
            self.timed_call.cancel()
        except (AlreadyCalled, AlreadyCancelled):
            pass
        self.timed_call = None

@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@ -140,9 +145,8 @@ class EmailPusher(object):
    being run.
    """
    start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
    fn = self.store.get_unread_push_actions_for_user_in_range_for_email
    unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)

    soonest_due_at = None
@ -190,7 +194,10 @@ class EmailPusher(object):
        soonest_due_at = should_notify_at

    if self.timed_call is not None:
        try:
            self.timed_call.cancel()
        except (AlreadyCalled, AlreadyCancelled):
            pass
        self.timed_call = None

    if soonest_due_at is not None:

View file

@ -16,6 +16,7 @@
from synapse.push import PusherConfigException

from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled

import logging
import push_rule_evaluator
@ -38,6 +39,7 @@ class HttpPusher(object):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.app_display_name = pusherdict['app_display_name']
@ -108,7 +110,11 @@ class HttpPusher(object):
def on_stop(self):
    if self.timed_call:
        try:
            self.timed_call.cancel()
        except (AlreadyCalled, AlreadyCancelled):
            pass
        self.timed_call = None

@defer.inlineCallbacks
def _process(self):
@ -140,7 +146,8 @@ class HttpPusher(object):
    run once per pusher.
    """
    fn = self.store.get_unread_push_actions_for_user_in_range_for_http
    unprocessed = yield fn(
        self.user_id, self.last_stream_ordering, self.max_stream_ordering
    )
@ -237,7 +244,9 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
    ctx = yield push_tools.get_context_for_event(
        self.state_handler, event, self.user_id
    )

    d = {
        'notification': {
@ -269,8 +278,8 @@ class HttpPusher(object):
    if 'content' in event:
        d['notification']['content'] = event.content

    # We no longer send aliases separately, instead, we send the human
    # readable name of the room, which may be an alias.
    if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
        d['notification']['sender_display_name'] = ctx['sender_display_name']
    if 'name' in ctx and len(ctx['name']) > 0:

View file

@ -14,6 +14,9 @@
# limitations under the License.

from twisted.internet import defer

from synapse.util.presentable_names import (
    calculate_room_name, name_from_member_event
)
@defer.inlineCallbacks
@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
def get_context_for_event(state_handler, ev, user_id):
    ctx = {}

    room_state = yield state_handler.get_current_state(ev.room_id)

    # we no longer bother setting room_alias, and make room_name the
    # human-readable name instead, be that m.room.name, an alias or
    # a list of people in the room
    name = calculate_room_name(
        room_state, user_id, fallback_to_single_member=False
    )
    if name:
        ctx['name'] = name

    sender_state_event = room_state[("m.room.member", ev.sender)]
    ctx['sender_display_name'] = name_from_member_event(sender_state_event)

    defer.returnValue(ctx)
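The context handed to the pusher now looks roughly like this (values invented): calculate_room_name can derive a name such as a members list for rooms with no m.room.name or alias, and fallback_to_single_member=False keeps one-to-one rooms from being named after the notification's own recipient.

    {
        "name": "Alice and Bob",                   # from calculate_room_name
        "sender_display_name": "Alice Margatroid"  # from name_from_member_event
    }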

View file

@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"], "Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"], "bleach>=1.4.2": ["bleach>=1.4.2"],
}, },
"ldap": {
"ldap3>=1.0": ["ldap3>=1.0"],
},
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
} }

View file

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
].orig

View file

@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
@ -95,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
StreamStore.__dict__["get_recent_event_ids_for_room"] StreamStore.__dict__["get_recent_event_ids_for_room"]
) )
get_unread_push_actions_for_user_in_range = ( get_unread_push_actions_for_user_in_range_for_http = (
DataStore.get_unread_push_actions_for_user_in_range.__func__ DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
) )
get_push_action_users_in_range = ( get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__ DataStore.get_push_action_users_in_range.__func__
@ -144,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__

get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__

get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__

def stream_positions(self):
    result = super(SlavedEventStore, self).stream_positions()
    result["events"] = self._stream_id_gen.get_current_token()
@ -202,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
@ -246,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate((
    event.room_id, event.type, event.state_key
))
pass

View file

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
"_get_server_verify_key"
]
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
store_server_verify_key = DataStore.store_server_verify_key.__func__
get_server_certificate = DataStore.get_server_certificate.__func__
store_server_certificate = DataStore.store_server_certificate.__func__
get_server_keys_json = DataStore.get_server_keys_json.__func__
store_server_keys_json = DataStore.store_server_keys_json.__func__

View file

@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
class RoomStore(BaseSlavedStore):
get_public_room_ids = DataStore.get_public_room_ids.__func__

View file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings"
].orig
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
# For now, don't record the destination retry timings
def set_destination_retry_timings(*args, **kwargs):
return defer.succeed(None)

View file

@@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import (
    account_data,
    report_event,
    openid,
+    devices,
)

from synapse.http.server import JsonResource
@@ -90,3 +91,4 @@ class ClientRestResource(JsonResource):
        account_data.register_servlets(hs, client_resource)
        report_event.register_servlets(hs, client_resource)
        openid.register_servlets(hs, client_resource)
+        devices.register_servlets(hs, client_resource)

View file

@@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
        defer.returnValue((200, ret))
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
super(PurgeMediaCacheRestServlet, self).__init__(hs)
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None)
if not before_ts:
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
ret = yield self.media_repository.delete_old_remote_media(before_ts)
defer.returnValue((200, ret))
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
yield self.handlers.message_handler.purge_history(room_id, event_id)
defer.returnValue((200, {}))
class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(target_user_id)
yield self.store.user_delete_threepids(target_user_id)
yield self.store.user_set_password_hash(target_user_id, None)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
    WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
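
For reference, the new admin endpoints are plain POSTs guarded by an is_server_admin check. A hedged sketch of driving purge_media_cache from a script (the /_matrix/client/api/v1 prefix, host and token are assumptions; before_ts is a query parameter in milliseconds):

import requests

# Hypothetical invocation of the purge_media_cache admin API added above.
# The URL prefix and response shape are assumptions based on the servlet;
# the access token must belong to a server admin.
resp = requests.post(
    "https://localhost:8448/_matrix/client/api/v1/admin/purge_media_cache",
    params={"access_token": "ADMIN_TOKEN", "before_ts": "1470000000000"},
    verify=False,  # test homeservers commonly use self-signed certs
)
print(resp.status_code, resp.json())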

View file

@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
    """

    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
        self.hs = hs
        self.handlers = hs.get_handlers()
        self.builder_factory = hs.get_event_builder_factory()

View file

@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
        self.servername = hs.config.server_name
        self.http_client = hs.get_simple_http_client()
        self.auth_handler = self.hs.get_auth_handler()
+        self.device_handler = self.hs.get_device_handler()

    def on_GET(self, request):
        flows = []
@@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
        ).to_string()

        auth_handler = self.auth_handler
-        user_id, access_token, refresh_token = yield auth_handler.login_with_password(
-            user_id=user_id,
-            password=login_submission["password"])
+        user_id = yield auth_handler.validate_password_login(
+            user_id=user_id,
+            password=login_submission["password"],
+        )
+        device_id = yield self._register_device(user_id, login_submission)
+        access_token, refresh_token = (
+            yield auth_handler.get_login_tuple_for_user_id(
+                user_id, device_id,
+                login_submission.get("initial_device_display_name")
+            )
+        )
        result = {
            "user_id": user_id,  # may have changed
            "access_token": access_token,
            "refresh_token": refresh_token,
            "home_server": self.hs.hostname,
+            "device_id": device_id,
        }

        defer.returnValue((200, result))
@@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
        user_id = (
            yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
        )
-        user_id, access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(user_id)
+        device_id = yield self._register_device(user_id, login_submission)
+        access_token, refresh_token = (
+            yield auth_handler.get_login_tuple_for_user_id(
+                user_id, device_id,
+                login_submission.get("initial_device_display_name")
+            )
        )
        result = {
            "user_id": user_id,  # may have changed
            "access_token": access_token,
            "refresh_token": refresh_token,
            "home_server": self.hs.hostname,
+            "device_id": device_id,
        }

        defer.returnValue((200, result))
@@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
        user_id = UserID.create(user, self.hs.hostname).to_string()

        auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if user_exists:
-            user_id, access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(user_id)
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if registered_user_id:
+            access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(
+                    registered_user_id
+                )
            )
            result = {
-                "user_id": user_id,  # may have changed
+                "user_id": registered_user_id,  # may have changed
                "access_token": access_token,
                "refresh_token": refresh_token,
                "home_server": self.hs.hostname,
@@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
        user_id = UserID.create(user, self.hs.hostname).to_string()

        auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if user_exists:
-            user_id, access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(user_id)
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if registered_user_id:
+            device_id = yield self._register_device(
+                registered_user_id, login_submission
+            )
+            access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(
+                    registered_user_id, device_id,
+                    login_submission.get("initial_device_display_name")
+                )
            )
            result = {
-                "user_id": user_id,  # may have changed
+                "user_id": registered_user_id,
                "access_token": access_token,
                "refresh_token": refresh_token,
                "home_server": self.hs.hostname,
            }
        else:
+            # TODO: we should probably check that the register isn't going
+            # to fonx/change our user_id before registering the device
+            device_id = yield self._register_device(user_id, login_submission)
            user_id, access_token = (
                yield self.handlers.registration_handler.register(localpart=user)
            )
@@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
        return (user, attributes)

+    def _register_device(self, user_id, login_submission):
+        """Register a device for a user.
+
+        This is called after the user's credentials have been validated, but
+        before the access token has been issued.
+
+        Args:
+            (str) user_id: full canonical @user:id
+            (object) login_submission: dictionary supplied to /login call, from
+                which we pull device_id and initial_device_name
+        Returns:
+            defer.Deferred: (str) device_id
+        """
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get(
+            "initial_device_display_name")
+        return self.device_handler.check_device_registered(
+            user_id, device_id, initial_display_name
+        )
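
With _register_device in place, a password login can carry an optional device_id and initial_device_display_name, and the response now includes the device_id that was used or allocated. A sketch of the exchange (host, token and ids are placeholders, not from this diff):

import json
import requests

body = {
    "type": "m.login.password",
    "user": "@alice:localhost",
    "password": "s3cret",
    # both optional: if device_id is omitted, the server allocates one
    "device_id": "MYLAPTOP",
    "initial_device_display_name": "Alice's laptop",
}
resp = requests.post(
    "https://localhost:8448/_matrix/client/r0/login",
    data=json.dumps(body), verify=False,
)
# expected keys: user_id, access_token, refresh_token, home_server, device_id
print(resp.json())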

class SAML2RestServlet(ClientV1RestServlet):
    PATTERNS = client_path_patterns("/login/saml2", releases=())
@@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
        user_id = UserID.create(user, self.hs.hostname).to_string()
        auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if not user_exists:
-            user_id, _ = (
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id, _ = (
                yield self.handlers.registration_handler.register(localpart=user)
            )

-        login_token = auth_handler.generate_short_term_login_token(user_id)
+        login_token = auth_handler.generate_short_term_login_token(registered_user_id)
        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                            login_token)
        request.redirect(redirect_url)

View file

@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
    PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)

    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
        super(RegisterRestServlet, self).__init__(hs)
        # sessions are stored as:
        # self.sessions = {
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
        # TODO: persistent storage
        self.sessions = {}
        self.enable_registration = hs.config.enable_registration
+        self.auth_handler = hs.get_auth_handler()

    def on_GET(self, request):
        if self.hs.config.enable_registration_captcha:
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
        user_localpart = register_json["user"].encode("utf-8")

        handler = self.handlers.registration_handler
-        (user_id, token) = yield handler.appservice_register(
+        user_id = yield handler.appservice_register(
            user_localpart, as_token
        )
+        token = yield self.auth_handler.issue_access_token(user_id)
        self._remove_session(session)
        defer.returnValue({
            "user_id": user_id,
@@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
            raise SynapseError(400, "Shared secret registration is not enabled")

        user = register_json["user"].encode("utf-8")
+        password = register_json["password"].encode("utf-8")
+        admin = register_json.get("admin", None)
+
+        # Its important to check as we use null bytes as HMAC field separators
+        if "\x00" in user:
+            raise SynapseError(400, "Invalid user")
+        if "\x00" in password:
+            raise SynapseError(400, "Invalid password")

        # str() because otherwise hmac complains that 'unicode' does not
        # have the buffer interface
@@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
        want_mac = hmac.new(
            key=self.hs.config.registration_shared_secret,
-            msg=user,
            digestmod=sha1,
-        ).hexdigest()
-
-        password = register_json["password"].encode("utf-8")
+        )
+        want_mac.update(user)
+        want_mac.update("\x00")
+        want_mac.update(password)
+        want_mac.update("\x00")
+        want_mac.update("admin" if admin else "notadmin")
+        want_mac = want_mac.hexdigest()

        if compare_digest(want_mac, got_mac):
            handler = self.handlers.registration_handler
            user_id, token = yield handler.register(
                localpart=user,
                password=password,
+                admin=bool(admin),
            )
            self._remove_session(session)
            defer.returnValue({
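
Note for operators scripting this endpoint: the MAC is no longer computed over the username alone but over the username, password and admin flag, NUL-separated. A minimal sketch of the matching client-side calculation (the shared secret value is illustrative):

import hmac
from hashlib import sha1

def generate_mac(shared_secret, user, password, admin=False):
    # mirrors the server-side want_mac construction above: NUL bytes
    # separate the fields so neither user nor password can smuggle one in
    mac = hmac.new(key=shared_secret, digestmod=sha1)
    mac.update(user)
    mac.update(b"\x00")
    mac.update(password)
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()

print(generate_mac(b"shared-secret-from-config", b"alice", b"wonderland"))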
@@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
            raise SynapseError(400, "Failed to parse 'duration_seconds'")
        if duration_seconds > self.direct_user_creation_max_duration:
            duration_seconds = self.direct_user_creation_max_duration
+        password_hash = user_json["password_hash"].encode("utf-8") \
+            if user_json.get("password_hash") else None

        handler = self.handlers.registration_handler
        user_id, token = yield handler.get_or_create_user(
            localpart=localpart,
            displayname=displayname,
-            duration_seconds=duration_seconds
+            duration_in_ms=(duration_seconds * 1000),
+            password_hash=password_hash
        )

        defer.returnValue({

View file

@@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
+from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request

import logging
import urllib
+import ujson as json

logger = logging.getLogger(__name__)
@@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
            request, default_limit=10,
        )
        as_client_event = "raw" not in request.args
+        filter_bytes = request.args.get("filter", None)
+        if filter_bytes:
+            filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+            event_filter = Filter(json.loads(filter_json))
+        else:
+            event_filter = None
        handler = self.handlers.message_handler
        msgs = yield handler.get_messages(
            room_id=room_id,
            requester=requester,
            pagin_config=pagination_config,
-            as_client_event=as_client_event
+            as_client_event=as_client_event,
+            event_filter=event_filter,
        )

        defer.returnValue((200, msgs))
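
The filter is supplied as URL-encoded JSON in the "filter" query parameter; together with the new 'contains_url' filter field (PR #922) this lets a client page through only those events that reference URLs. A sketch of building such a request (room id, token and host are placeholders):

import json
import urllib

# Hypothetical construction of a filtered /messages URL; the
# contains_url field comes from the filter change in PR #922.
event_filter = {"contains_url": True, "limit": 10}
query = urllib.urlencode({
    "access_token": "TOKEN",
    "from": "t1-123_456_789",
    "filter": json.dumps(event_filter),
})
print("https://localhost:8448/_matrix/client/r0/rooms/"
      "!room:localhost/messages?" + query)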

View file

@@ -25,7 +25,9 @@ import logging

logger = logging.getLogger(__name__)


-def client_v2_patterns(path_regex, releases=(0,)):
+def client_v2_patterns(path_regex, releases=(0,),
+                       v2_alpha=True,
+                       unstable=True):
    """Creates a regex compiled client path with the correct client path
    prefix.
@@ -35,7 +37,10 @@ def client_v2_patterns(path_regex, releases=(0,)):
    Returns:
        SRE_Pattern
    """
-    patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
-    unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
-    patterns.append(re.compile("^" + unstable_prefix + path_regex))
+    patterns = []
+    if v2_alpha:
+        patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
+    if unstable:
+        unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
+        patterns.append(re.compile("^" + unstable_prefix + path_regex))
    for release in releases:
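
With the new keyword arguments a servlet can opt out of the /v2_alpha mount; the devices servlets below use releases=[], v2_alpha=False so they appear only under /unstable. A sketch of what the helper yields, assuming CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" (the releases branch here is an assumption based on surrounding code the hunk above does not show):

import re

CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"

def client_v2_patterns(path_regex, releases=(0,), v2_alpha=True, unstable=True):
    patterns = []
    if v2_alpha:
        patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
    if unstable:
        unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
        patterns.append(re.compile("^" + unstable_prefix + path_regex))
    for release in releases:
        release_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
        patterns.append(re.compile("^" + release_prefix + path_regex))
    return patterns

# only the /unstable prefix matches for the devices servlets:
print([p.pattern for p in client_v2_patterns("/devices$", releases=[], v2_alpha=False)])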

View file

@@ -28,8 +28,40 @@ import logging

logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/account/password")
+    PATTERNS = client_v2_patterns("/account/password$")

    def __init__(self, hs):
        super(PasswordRestServlet, self).__init__()
@@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet):
        return 200, {}
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {}))
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/account/3pid")
+    PATTERNS = client_v2_patterns("/account/3pid$")

    def __init__(self, hs):
        super(ThreepidRestServlet, self).__init__()
@@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):

def register_servlets(hs, http_server):
+    PasswordRequestTokenRestServlet(hs).register(http_server)
    PasswordRestServlet(hs).register(http_server)
+    DeactivateAccountRestServlet(hs).register(http_server)
+    ThreepidRequestTokenRestServlet(hs).register(http_server)
    ThreepidRestServlet(hs).register(http_server)
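
All three requestToken servlets added in this release validate the same four fields before handing off to the identity handler; they differ only in whether the email must already be registered (password reset) or must not be (registration, 3pid add). The common request shape, sketched with placeholder values:

import json
import requests

body = {
    "id_server": "vector.im",
    "client_secret": "a-client-generated-secret",
    "email": "alice@example.com",
    "send_attempt": 1,
}
# hypothetical call against the password-reset variant registered above
resp = requests.post(
    "https://localhost:8448/_matrix/client/unstable"
    "/account/password/email/requestToken",
    data=json.dumps(body), verify=False,
)
print(resp.status_code, resp.json())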

View file

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
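
Taken together with the GET /devices work from PRs #939 and #944, these servlets give clients full CRUD over their devices, but only under the /unstable prefix for now (see the PATTERNS above). A hedged usage sketch (host, token and device id are placeholders):

import json
import requests

BASE = "https://localhost:8448/_matrix/client/unstable"
AUTH = {"access_token": "TOKEN"}

# list the authenticated user's devices
print(requests.get(BASE + "/devices", params=AUTH, verify=False).json())

# rename one of them
requests.put(
    BASE + "/devices/MYLAPTOP", params=AUTH, verify=False,
    data=json.dumps({"display_name": "Alice's work laptop"}),
)

# delete it; per the XXX above this also removes its access tokens
requests.delete(BASE + "/devices/MYLAPTOP", params=AUTH, verify=False)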

View file

@@ -13,24 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import logging
+
+import simplejson as json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer

+import synapse.api.errors
+import synapse.server
+import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-from canonicaljson import encode_canonical_json

from ._base import client_v2_patterns

-import logging
-import simplejson as json
-
logger = logging.getLogger(__name__)


class KeyUploadServlet(RestServlet):
    """
-    POST /keys/upload/<device_id> HTTP/1.1
+    POST /keys/upload HTTP/1.1
    Content-Type: application/json

    {
@@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
        },
    }
    """
-    PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
+                                  releases=())

    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
        super(KeyUploadServlet, self).__init__()
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()

    @defer.inlineCallbacks
    def on_POST(self, request, device_id):
        requester = yield self.auth.get_user_by_req(request)
        user_id = requester.user.to_string()
-        # TODO: Check that the device_id matches that in the authentication
-        # or derive the device_id from the authentication instead.
        body = parse_json_object_from_request(request)

+        if device_id is not None:
+            # passing the device_id here is deprecated; however, we allow it
+            # for now for compatibility with older clients.
+            if (requester.device_id is not None and
+                    device_id != requester.device_id):
+                logger.warning("Client uploading keys for a different device "
+                               "(logged in as %s, uploading for %s)",
+                               requester.device_id, device_id)
+        else:
+            device_id = requester.device_id
+
+        if device_id is None:
+            raise synapse.api.errors.SynapseError(
+                400,
+                "To upload keys, you must pass device_id when authenticating"
+            )
+
        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
@@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
                user_id, device_id, time_now, key_list
            )

-        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
-        defer.returnValue((200, {"one_time_key_counts": result}))
-
-    @defer.inlineCallbacks
-    def on_GET(self, request, device_id):
-        requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        # the device should have been registered already, but it may have been
+        # deleted due to a race with a DELETE request. Or we may be using an
+        # old access_token without an associated device_id. Either way, we
+        # need to double-check the device is registered to avoid ending up with
+        # keys without a corresponding device.
+        self.device_handler.check_device_registered(user_id, device_id)

        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
        defer.returnValue((200, {"one_time_key_counts": result}))
@@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
    )

    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
        super(KeyQueryServlet, self).__init__()
-        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
-        self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()

    @defer.inlineCallbacks
    def on_POST(self, request, user_id, device_id):
        yield self.auth.get_user_by_req(request)
        body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
+        result = yield self.e2e_keys_handler.query_devices(body)
        defer.returnValue(result)

    @defer.inlineCallbacks
@@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
        auth_user_id = requester.user.to_string()
        user_id = user_id if user_id else auth_user_id
        device_ids = [device_id] if device_id else []
-        result = yield self.handle_request(
+        result = yield self.e2e_keys_handler.query_devices(
            {"device_keys": {user_id: device_ids}}
        )
        defer.returnValue(result)

-    @defer.inlineCallbacks
-    def handle_request(self, body):
-        local_query = []
-        remote_queries = {}
-        for user_id, device_ids in body.get("device_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
-                if not device_ids:
-                    local_query.append((user_id, None))
-                else:
-                    for device_id in device_ids:
-                        local_query.append((user_id, device_id))
-            else:
-                remote_queries.setdefault(user.domain, {})[user_id] = list(
-                    device_ids
-                )
-        results = yield self.store.get_e2e_device_keys(local_query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.query_client_keys(
-                destination, {"device_keys": device_keys}
-            )
-            for user_id, keys in remote_result["device_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-
-        defer.returnValue((200, {"device_keys": json_result}))

class OneTimeKeyServlet(RestServlet):
    """

View file

@@ -41,17 +41,59 @@ else:
logger = logging.getLogger(__name__)


+class RegisterRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/register/email/requestToken$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(RegisterRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        required = ['id_server', 'client_secret', 'email', 'send_attempt']
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if len(absent) > 0:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+            'email', body['email']
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+        ret = yield self.identity_handler.requestEmailToken(**body)
+        defer.returnValue((200, ret))
+
+
class RegisterRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/register")
+    PATTERNS = client_v2_patterns("/register$")

    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
        super(RegisterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.store = hs.get_datastore()
        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_handlers().registration_handler
        self.identity_handler = hs.get_handlers().identity_handler
+        self.device_handler = hs.get_device_handler()

    @defer.inlineCallbacks
    def on_POST(self, request):
@@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
                "Do not understand membership kind: %s" % (kind,)
            )

-        if '/register/email/requestToken' in request.path:
-            ret = yield self.onEmailTokenRequest(request)
-            defer.returnValue(ret)
-
        body = parse_json_object_from_request(request)

        # we do basic sanity checks here because the auth layer will store these
@@ -104,10 +142,11 @@ class RegisterRestServlet(RestServlet):
            # Set the desired user according to the AS API (which uses the
            # 'user' key not 'username'). Since this is a new addition, we'll
            # fallback to 'username' if they gave one.
-            if isinstance(body.get("user"), basestring):
-                desired_username = body["user"]
+            desired_username = body.get("user", desired_username)
+            if isinstance(desired_username, basestring):
                result = yield self._do_appservice_registration(
-                    desired_username, request.args["access_token"][0]
+                    desired_username, request.args["access_token"][0], body
                )
            defer.returnValue((200, result))  # we throw for non 200 responses
            return
@@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
            # FIXME: Should we really be determining if this is shared secret
            # auth based purely on the 'mac' key?
            result = yield self._do_shared_secret_registration(
-                desired_username, desired_password, body["mac"]
+                desired_username, desired_password, body
            )
            defer.returnValue((200, result))  # we throw for non 200 responses
            return
@@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
                [LoginType.EMAIL_IDENTITY]
            ]

-        authed, result, params, session_id = yield self.auth_handler.check_auth(
+        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
            flows, body, self.hs.get_ip_from_request(request)
        )

        if not authed:
-            defer.returnValue((401, result))
+            defer.returnValue((401, auth_result))
            return

        if registered_user_id is not None:
@@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
                "Already registered user ID %r for this session",
                registered_user_id
            )
-            access_token = yield self.auth_handler.issue_access_token(registered_user_id)
-            refresh_token = yield self.auth_handler.issue_refresh_token(
-                registered_user_id
-            )
-            defer.returnValue((200, {
-                "user_id": registered_user_id,
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-                "refresh_token": refresh_token,
-            }))
-
-        # NB: This may be from the auth handler and NOT from the POST
-        if 'password' not in params:
-            raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
-
-        desired_username = params.get("username", None)
-        new_password = params.get("password", None)
-        guest_access_token = params.get("guest_access_token", None)
-
-        (user_id, token) = yield self.registration_handler.register(
-            localpart=desired_username,
-            password=new_password,
-            guest_access_token=guest_access_token,
-        )
-
-        # remember that we've now registered that user account, and with what
-        # user ID (since the user may not have specified)
-        self.auth_handler.set_session_data(
-            session_id, "registered_user_id", user_id
-        )
-
-        if result and LoginType.EMAIL_IDENTITY in result:
-            threepid = result[LoginType.EMAIL_IDENTITY]
-            for reqd in ['medium', 'address', 'validated_at']:
-                if reqd not in threepid:
-                    logger.info("Can't add incomplete 3pid")
-                else:
-                    yield self.auth_handler.add_threepid(
-                        user_id,
-                        threepid['medium'],
-                        threepid['address'],
-                        threepid['validated_at'],
-                    )
-
-            # And we add an email pusher for them by default, but only if
-            # if email notifications are enabled (so people don't start
-            # getting mail spam where they weren't before if email
-            # notifs are set up on a home server)
-            if (
-                self.hs.config.email_enable_notifs and
-                self.hs.config.email_notif_for_new_users
-            ):
-                # Pull the ID of the access token back out of the db
-                # It would really make more sense for this to be passed
-                # up when the access token is saved, but that's quite an
-                # invasive change I'd rather do separately.
-                user_tuple = yield self.store.get_user_by_access_token(
-                    token
-                )
-
-                yield self.hs.get_pusherpool().add_pusher(
-                    user_id=user_id,
-                    access_token=user_tuple["token_id"],
-                    kind="email",
-                    app_id="m.email",
-                    app_display_name="Email Notifications",
-                    device_display_name=threepid["address"],
-                    pushkey=threepid["address"],
-                    lang=None,  # We don't know a user's language here
-                    data={},
-                )
-
-        if 'bind_email' in params and params['bind_email']:
-            logger.info("bind_email specified: binding")
-
-            emailThreepid = result[LoginType.EMAIL_IDENTITY]
-            threepid_creds = emailThreepid['threepid_creds']
-
-            logger.debug("Binding emails %s to %s" % (
-                emailThreepid, user_id
-            ))
-            yield self.identity_handler.bind_threepid(threepid_creds, user_id)
-        else:
-            logger.info("bind_email not specified: not binding email")
-
-        result = yield self._create_registration_details(user_id, token)
-        defer.returnValue((200, result))
+            # don't re-register the email address
+            add_email = False
+        else:
+            # NB: This may be from the auth handler and NOT from the POST
+            if 'password' not in params:
+                raise SynapseError(400, "Missing password.",
+                                   Codes.MISSING_PARAM)
+
+            desired_username = params.get("username", None)
+            new_password = params.get("password", None)
+            guest_access_token = params.get("guest_access_token", None)
+
+            (registered_user_id, _) = yield self.registration_handler.register(
+                localpart=desired_username,
+                password=new_password,
+                guest_access_token=guest_access_token,
+                generate_token=False,
+            )
+
+            # remember that we've now registered that user account, and with
+            # what user ID (since the user may not have specified)
+            self.auth_handler.set_session_data(
+                session_id, "registered_user_id", registered_user_id
+            )
+
+            add_email = True
+
+        return_dict = yield self._create_registration_details(
+            registered_user_id, params
+        )
+
+        if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+            threepid = auth_result[LoginType.EMAIL_IDENTITY]
+            yield self._register_email_threepid(
+                registered_user_id, threepid, return_dict["access_token"],
+                params.get("bind_email")
+            )
+
+        defer.returnValue((200, return_dict))
    def on_OPTIONS(self, _):
        return 200, {}

    @defer.inlineCallbacks
-    def _do_appservice_registration(self, username, as_token):
-        (user_id, token) = yield self.registration_handler.appservice_register(
+    def _do_appservice_registration(self, username, as_token, body):
+        user_id = yield self.registration_handler.appservice_register(
            username, as_token
        )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+        defer.returnValue((yield self._create_registration_details(user_id, body)))

    @defer.inlineCallbacks
-    def _do_shared_secret_registration(self, username, password, mac):
+    def _do_shared_secret_registration(self, username, password, body):
        if not self.hs.config.registration_shared_secret:
            raise SynapseError(400, "Shared secret registration is not enabled")
@@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):

        # str() because otherwise hmac complains that 'unicode' does not
        # have the buffer interface
-        got_mac = str(mac)
+        got_mac = str(body["mac"])

        want_mac = hmac.new(
            key=self.hs.config.registration_shared_secret,
@@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
                403, "HMAC incorrect",
            )

-        (user_id, token) = yield self.registration_handler.register(
-            localpart=username, password=password
+        (user_id, _) = yield self.registration_handler.register(
+            localpart=username, password=password, generate_token=False,
        )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+
+        result = yield self._create_registration_details(user_id, body)
+        defer.returnValue(result)
    @defer.inlineCallbacks
-    def _create_registration_details(self, user_id, token):
-        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
+    def _register_email_threepid(self, user_id, threepid, token, bind_email):
+        """Add an email address as a 3pid identifier
+
+        Also adds an email pusher for the email address, if configured in the
+        HS config
+
+        Also optionally binds emails to the given user_id on the identity server
+
+        Args:
+            user_id (str): id of user
+            threepid (object): m.login.email.identity auth response
+            token (str): access_token for the user
+            bind_email (bool): true if the client requested the email to be
+                bound at the identity server
+        Returns:
+            defer.Deferred:
+        """
+        reqd = ('medium', 'address', 'validated_at')
+        if any(x not in threepid for x in reqd):
+            logger.info("Can't add incomplete 3pid")
+            defer.returnValue()
+
+        yield self.auth_handler.add_threepid(
+            user_id,
+            threepid['medium'],
+            threepid['address'],
+            threepid['validated_at'],
+        )
+
+        # And we add an email pusher for them by default, but only
+        # if email notifications are enabled (so people don't start
+        # getting mail spam where they weren't before if email
+        # notifs are set up on a home server)
+        if (self.hs.config.email_enable_notifs and
+                self.hs.config.email_notif_for_new_users):
+            # Pull the ID of the access token back out of the db
+            # It would really make more sense for this to be passed
+            # up when the access token is saved, but that's quite an
+            # invasive change I'd rather do separately.
+            user_tuple = yield self.store.get_user_by_access_token(
+                token
+            )
+            token_id = user_tuple["token_id"]
+
+            yield self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="email",
+                app_id="m.email",
+                app_display_name="Email Notifications",
+                device_display_name=threepid["address"],
+                pushkey=threepid["address"],
+                lang=None,  # We don't know a user's language here
+                data={},
+            )
+
+        if bind_email:
+            logger.info("bind_email specified: binding")
+            logger.debug("Binding emails %s to %s" % (
+                threepid, user_id
+            ))
+            yield self.identity_handler.bind_threepid(
+                threepid['threepid_creds'], user_id
+            )
+        else:
+            logger.info("bind_email not specified: not binding email")
+
+    @defer.inlineCallbacks
+    def _create_registration_details(self, user_id, params):
+        """Complete registration of newly-registered user
+
+        Allocates device_id if one was not given; also creates access_token
+        and refresh_token.
+
+        Args:
+            (str) user_id: full canonical @user:id
+            (object) params: registration parameters, from which we pull
+                device_id and initial_device_name
+        Returns:
+            defer.Deferred: (object) dictionary for response from /register
+        """
+        device_id = yield self._register_device(user_id, params)
+
+        access_token, refresh_token = (
+            yield self.auth_handler.get_login_tuple_for_user_id(
+                user_id, device_id=device_id,
+                initial_display_name=params.get("initial_device_display_name")
+            )
+        )
+
        defer.returnValue({
            "user_id": user_id,
-            "access_token": token,
+            "access_token": access_token,
            "home_server": self.hs.hostname,
            "refresh_token": refresh_token,
+            "device_id": device_id,
        })

-    @defer.inlineCallbacks
-    def onEmailTokenRequest(self, request):
-        body = parse_json_object_from_request(request)
-
-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
-
-        if len(absent) > 0:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
-
-        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
-            'email', body['email']
-        )
-
-        if existingUid is not None:
-            raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
-
-        ret = yield self.identity_handler.requestEmailToken(**body)
-        defer.returnValue((200, ret))
+    def _register_device(self, user_id, params):
+        """Register a device for a user.
+
+        This is called after the user's credentials have been validated, but
+        before the access token has been issued.
+
+        Args:
+            (str) user_id: full canonical @user:id
+            (object) params: registration parameters, from which we pull
+                device_id and initial_device_name
+        Returns:
+            defer.Deferred: (str) device_id
+        """
+        # register the user's device
+        device_id = params.get("device_id")
+        initial_display_name = params.get("initial_device_display_name")
+        device_id = self.device_handler.check_device_registered(
+            user_id, device_id, initial_display_name
+        )
+        return device_id
    @defer.inlineCallbacks
    def _do_guest_registration(self):
@@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
            generate_token=False,
            make_guest=True
        )
-        access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
+        access_token = self.auth_handler.generate_access_token(
+            user_id, ["guest = true"]
+        )
+        # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
+        # so long as we don't return a refresh_token here.
        defer.returnValue((200, {
            "user_id": user_id,
            "access_token": access_token,
@@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):

def register_servlets(hs, http_server):
+    RegisterRequestTokenRestServlet(hs).register(http_server)
    RegisterRestServlet(hs).register(http_server)

View file

@@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
        try:
            old_refresh_token = body["refresh_token"]
            auth_handler = self.hs.get_auth_handler()
-            (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
-                old_refresh_token, auth_handler.generate_refresh_token)
-            new_access_token = yield auth_handler.issue_access_token(user_id)
+            refresh_result = yield self.store.exchange_refresh_token(
+                old_refresh_token, auth_handler.generate_refresh_token
+            )
+            (user_id, new_refresh_token, device_id) = refresh_result
+            new_access_token = yield auth_handler.issue_access_token(
+                user_id, device_id
+            )
            defer.returnValue((200, {
                "access_token": new_access_token,
                "refresh_token": new_refresh_token,

View file

@@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
    def on_GET(self, request):
        return (200, {
-            "versions": ["r0.0.1"]
+            "versions": [
+                "r0.0.1",
+                "r0.1.0",
+                "r0.2.0",
+            ]
        })

View file

@@ -15,14 +15,12 @@

from synapse.http.server import respond_with_json_bytes, finish_request
-from synapse.util.stringutils import random_string
from synapse.api.errors import (
-    cs_exception, SynapseError, CodeMessageException, Codes, cs_error
+    Codes, cs_error
)

from twisted.protocols.basic import FileSender
from twisted.web import server, resource
-from twisted.internet import defer

import base64
import simplejson as json
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
    """
    isLeaf = True

-    def __init__(self, hs, directory, auth, external_addr):
+    def __init__(self, hs, directory):
        resource.Resource.__init__(self)
        self.hs = hs
        self.directory = directory
-        self.auth = auth
-        self.external_addr = external_addr.rstrip('/')
-        self.max_upload_size = hs.config.max_upload_size
if not os.path.isdir(self.directory):
os.mkdir(self.directory)
logger.info("ContentRepoResource : Created %s directory.",
self.directory)
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
requester = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
requester.user.to_string()
).replace('=', '')
# use a random string for the main portion
main_part = random_string(24)
# suffix with a file extension if we can make one. This is nice to
# provide a hint to clients on the file information. We will also reuse
# this info to spit back the content type to the client.
suffix = ""
if request.requestHeaders.hasHeader("Content-Type"):
content_type = request.requestHeaders.getRawHeaders(
"Content-Type")[0]
suffix = "." + base64.urlsafe_b64encode(content_type)
if (content_type.split("/")[0].lower() in
["image", "video", "audio"]):
file_ext = content_type.split("/")[-1]
# be a little paranoid and only allow a-z
file_ext = re.sub("[^a-z]", "", file_ext)
suffix += "." + file_ext
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s",
request.user.user_id.to_string(),
file_path)
# keep trying to make a non-clashing file, with a sensible max attempts
attempts = 0
while os.path.exists(file_path):
main_part = random_string(24)
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
attempts += 1
if attempts > 25: # really? Really?
raise SynapseError(500, "Unable to create file.")
defer.returnValue(file_path)
    def render_GET(self, request):
        # no auth here on purpose, to allow anyone to view, even across home
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
        return server.NOT_DONE_YET

-    def render_POST(self, request):
-        self._async_render(request)
-        return server.NOT_DONE_YET
-
    def render_OPTIONS(self, request):
        respond_with_json_bytes(request, 200, {}, send_cors=True)
        return server.NOT_DONE_YET
@defer.inlineCallbacks
def _async_render(self, request):
try:
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
)
fname = yield self.map_request_to_name(request)
# TODO I have a suspicious feeling this is just going to block
with open(fname, "wb") as f:
f.write(request.content.read())
# FIXME (erikj): These should use constants.
file_name = os.path.basename(fname)
# FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to
# serve it via the non-SSL listener...
url = "%s/_matrix/content/%s" % (
self.external_addr, file_name
)
respond_with_json_bytes(request, 200,
json.dumps({"content_token": url}),
send_cors=True)
except CodeMessageException as e:
logger.exception(e)
respond_with_json_bytes(request, e.code,
json.dumps(cs_exception(e)))
except Exception as e:
logger.error("Failed to store file: %s" % e)
respond_with_json_bytes(
request,
500,
json.dumps({"error": "Internal server error"}),
send_cors=True)

View file

@@ -65,3 +65,9 @@ class MediaFilePaths(object):
            file_id[0:2], file_id[2:4], file_id[4:],
            file_name
        )
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)
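
Like the existing media path helpers, the new thumbnail directory is sharded on the first two character pairs of the file id so no single directory grows unboundedly. A standalone sketch of the layout this produces (base path and ids are examples):

import os

def remote_media_thumbnail_dir(base_path, server_name, file_id):
    # mirrors the helper above, minus the class plumbing
    return os.path.join(
        base_path, "remote_thumbnail", server_name,
        file_id[0:2], file_id[2:4], file_id[4:],
    )

print(remote_media_thumbnail_dir(
    "/var/lib/synapse/media", "matrix.org", "abcdefghij"
))
# -> /var/lib/synapse/media/remote_thumbnail/matrix.org/ab/cd/efghij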

View file

@@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError

from twisted.internet import defer, threads

-from synapse.util.async import ObservableDeferred
+from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn

import os
+import errno
+import shutil
import cgi
import logging
@ -43,8 +45,11 @@ import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
class MediaRepository(object):
def __init__(self, hs, filepaths):
def __init__(self, hs):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@@ -52,11 +57,28 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer()
self.recently_accessed_remotes = set()
self.clock.looping_call(
self._update_recently_accessed_remotes,
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
)
@defer.inlineCallbacks
def _update_recently_accessed_remotes(self):
media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
yield self.store.update_cached_last_access_time(
media, self.clock.time_msec()
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
@@ -93,22 +115,12 @@ class MediaRepository(object):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
with (yield self.remote_media_linearizer.queue(key)):
media_info = yield self._get_remote_media_impl(server_name, media_id)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -119,6 +131,11 @@ class MediaRepository(object):
media_info = yield self._download_remote_file(
server_name, media_id
)
else:
self.recently_accessed_remotes.add((server_name, media_id))
yield self.store.update_cached_last_access_time(
[(server_name, media_id)], self.clock.time_msec()
)
defer.returnValue(media_info)
@defer.inlineCallbacks
@@ -416,6 +433,41 @@ class MediaRepository(object):
"height": m_height,
})
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
deleted = 0
for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
key = (origin, media_id)
logger.info("Deleting: %r", key)
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
except OSError as e:
logger.warn("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
continue
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
origin, file_id
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
defer.returnValue({"deleted": deleted})
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
@@ -464,9 +516,8 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path)
media_repo = MediaRepository(hs, filepaths)
media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
@@ -29,6 +29,8 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os
import re
import fnmatch
@@ -329,20 +331,24 @@ class PreviewUrlResource(Resource):
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
"ancestor::aside | ancestor::footer | "
"ancestor::script | ancestor::style)]" +
"[ancestor::body]")
text = ''
for text_node in text_nodes:
if len(text) < 500:
text += text_node + ' '
else:
break
text = re.sub(r'[\t ]+', ' ', text)
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
text = text.strip()[:500]
og['og:description'] = text if text else None
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring)  # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
@@ -450,3 +456,56 @@ class PreviewUrlResource(Resource):
content_type.startswith("application/xhtml")
):
return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# Try to get a summary of between 200 and 500 characters, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?
description = ''
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
description += text_node + '\n\n'
else:
break
description = description.strip()
description = re.sub(r'[\t ]+', ' ', description)
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
if len(description) > max_size:
new_desc = ""
# This splits the paragraph into words, but keeping the
# (preceding) whitespace intact so we can easily concatenate
# words back together.
for match in re.finditer("\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
# MAX_SIZE.
if len(word) + len(new_desc) < max_size:
new_desc += word
else:
# At this point the next word *will* take us over
# MAX_SIZE, but we also want to ensure that it's not
# a huge word. If it is, add it anyway and we'll
# truncate later.
if len(new_desc) < min_size:
new_desc += word
break
# Double check that we're not over the limit
if len(new_desc) > max_size:
new_desc = new_desc[:max_size]
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
description = new_desc.strip() + "…"
return description if description else None
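To make the behaviour concrete, a hypothetical invocation (paragraph text invented for illustration): with min_size=100 and max_size=130, the function keeps appending whole paragraphs until the description passes 100 characters, then truncates at a word boundary and appends "…":

    paragraphs = [
        "Matrix is an open protocol for secure, decentralised communication.",
        "Synapse is the reference homeserver implementation of the protocol.",
        "A third paragraph would be dropped once the minimum size is reached.",
    ]
    # Crosses min_size after two paragraphs (~140 chars), which is over
    # max_size, so the result is cut back to a word boundary under 130
    # characters and ends with an ellipsis.
    summary = summarize_paragraphs(paragraphs, min_size=100, max_size=130)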
@@ -19,37 +19,38 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
import logging
from twisted.enterprise import adbapi
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
logger = logging.getLogger(__name__)
@@ -91,6 +92,8 @@ class HomeServer(object):
'typing_handler',
'room_list_handler',
'auth_handler',
'device_handler',
'e2e_keys_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -113,6 +116,7 @@ class HomeServer(object):
'filtering',
'http_client_context_factory',
'simple_http_client',
'media_repository',
]
def __init__(self, hostname, **kwargs):
@@ -195,6 +199,12 @@ class HomeServer(object):
def build_auth_handler(self):
return AuthHandler(self)
def build_device_handler(self):
return DeviceHandler(self)
def build_e2e_keys_handler(self):
return E2eKeysHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)
@@ -233,6 +243,9 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
def build_media_repository(self):
return MediaRepository(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
synapse/server.pyi (new file, 25 lines)
@@ -0,0 +1,25 @@
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.storage
import synapse.state
class HomeServer(object):
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
def get_datastore(self) -> synapse.storage.DataStore:
pass
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
pass
def get_handlers(self) -> synapse.handlers.Handlers:
pass
def get_state_handler(self) -> synapse.state.StateHandler:
pass
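The stub only declares the getters that type-checked call sites need; at runtime each get_<dep> is backed by a memoised build_<dep>. A rough, simplified sketch of that lazy-singleton pattern (not the actual metaprogramming in synapse/server.py):

    class LazyDeps(object):
        # Each build_<name> runs at most once; get() memoises the result.
        def build_media_repository(self):
            return object()  # stand-in for MediaRepository(self)

        def get(self, depname):
            attr = "_cached_" + depname
            if not hasattr(self, attr):
                setattr(self, attr, getattr(self, "build_" + depname)())
            return getattr(self, attr)

    deps = LazyDeps()
    assert deps.get("media_repository") is deps.get("media_repository")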
@@ -379,7 +379,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
@@ -391,7 +392,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
):
def __init__(self, db_conn, hs):
@@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
@@ -597,10 +597,13 @@ class SQLBaseStore(object):
more rows, returning the result as a list of dicts.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the rows with,
or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
@@ -615,9 +618,11 @@ class SQLBaseStore(object):
Args:
txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
table (str): the table name
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -807,6 +812,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
def _simple_delete(self, table, keyvalues, desc):
return self.runInteraction(
desc, self._simple_delete_txn, table, keyvalues
)
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from . import engines
from twisted.internet import defer
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
assert self._background_update_timer is None, \
"background updates already running"
logger.info("Starting background schema updates")
while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
result = yield self.do_background_update(
result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
else:
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
defer.returnValue(None)
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
"""Does some amount of work on the next queued background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
# no work left to do
defer.returnValue(None)
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
res = yield self._do_background_update(update_name, desired_duration_ms)
defer.returnValue(res)
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
logger.info("Starting update batch on background update '%s'",
update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
def register_background_index_update(self, update_name, index_name,
table, columns):
"""Helper for store classes to do a background index addition
To use:
1. use a schema delta file to add a background update. Example:
INSERT INTO background_updates (update_name, progress_json) VALUES
('my_new_index', '{}');
2. In the Store constructor, call this method
Args:
update_name (str): update_name to register for
index_name (str): name of index to add
table (str): table to add index to
columns (list[str]): columns/expressions to include in index
"""
# if this is postgres, we add the indexes concurrently. Otherwise
# we fall back to doing it inline
if isinstance(self.database_engine, engines.PostgresEngine):
conc = True
else:
conc = False
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
% {
"conc": "CONCURRENTLY" if conc else "",
"name": index_name,
"table": table,
"columns": ", ".join(columns),
}
def create_index_concurrently(conn):
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
c = conn.cursor()
c.execute(sql)
conn.set_session(autocommit=False)
def create_index(conn):
c = conn.cursor()
c.execute(sql)
@defer.inlineCallbacks
def updater(progress, batch_size):
logger.info("Adding index %s to %s", index_name, table)
if conc:
yield self.runWithConnection(create_index_concurrently)
else:
yield self.runWithConnection(create_index)
yield self._end_background_update(update_name)
defer.returnValue(1)
self.register_background_update_handler(update_name, updater)
def start_background_update(self, update_name, progress):
"""Starts a background update running.
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Cache
import logging
from twisted.internet import defer
from ._base import Cache
from . import background_updates
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
@@ -24,8 +28,7 @@ from twisted.internet import defer
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(SQLBaseStore):
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
@@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
super(ClientIpStore, self).__init__(hs)
self.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": device_id,
},
values={
"last_seen": now,
@@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
desc="insert_client_ip",
lock=False,
)
@defer.inlineCallbacks
def get_last_client_ip_by_device(self, devices):
"""For each device_id listed, give the user_ip it was last seen on
Args:
devices (iterable[(str, str)]): list of (user_id, device_id) pairs
Returns:
defer.Deferred: resolves to a dict, where the keys
are (user_id, device_id) tuples. The values are also dicts, with
keys giving the column names
"""
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
retcols=(
"user_id",
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
),
devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
defer.returnValue(ret)
@classmethod
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
where_clauses = []
bindings = []
for (user_id, device_id) in devices:
if device_id is None:
where_clauses.append("(user_id = ? AND device_id IS NULL)")
bindings.extend((user_id, ))
else:
where_clauses.append("(user_id = ? AND device_id = ?)")
bindings.extend((user_id, device_id))
inner_select = (
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s "
"GROUP BY user_id, device_id"
) % {
"where": " OR ".join(where_clauses),
}
sql = (
"SELECT %(retcols)s FROM user_ips "
"JOIN (%(inner_select)s) ips ON"
" user_ips.last_seen = ips.mls AND"
" user_ips.user_id = ips.user_id AND"
" (user_ips.device_id = ips.device_id OR"
" (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
" )"
) % {
"retcols": ",".join("user_ips." + c for c in retcols),
"inner_select": inner_select,
}
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
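To make the WHERE construction concrete, this is what the loop above yields for one concrete and one NULL device_id (a standalone re-run of the same logic, ids invented):

    devices = [("@alice:example.com", "ABCDEFGH"), ("@bob:example.com", None)]

    where_clauses = []
    bindings = []
    for (user_id, device_id) in devices:
        if device_id is None:
            where_clauses.append("(user_id = ? AND device_id IS NULL)")
            bindings.extend((user_id,))
        else:
            where_clauses.append("(user_id = ? AND device_id = ?)")
            bindings.extend((user_id, device_id))

    assert " OR ".join(where_clauses) == (
        "(user_id = ? AND device_id = ?)"
        " OR (user_id = ? AND device_id IS NULL)"
    )
    assert bindings == ["@alice:example.com", "ABCDEFGH", "@bob:example.com"]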
synapse/storage/devices.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name,
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not
Args:
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
device
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns:
defer.Deferred
Raises:
StoreError: if ignore_if_known is False and the device was already
known
"""
try:
yield self._simple_insert(
"devices",
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name
},
desc="store_device",
or_ignore=ignore_if_known,
)
except Exception as e:
logger.error("store_device with device_id=%s failed: %s",
device_id, e)
raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
def delete_device(self, user_id, device_id):
"""Delete a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to delete
Returns:
defer.Deferred
"""
return self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import twisted.internet.defer
from ._base import SQLBaseStore
@@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key json byte strings.
dict containing "key_json", "device_display_name".
"""
def _get_e2e_device_keys(txn):
result = {}
for user_id, device_id in query_list:
user_result = result.setdefault(user_id, {})
keyvalues = {"user_id": user_id}
if device_id:
keyvalues["device_id"] = device_id
rows = self._simple_select_list_txn(
txn, table="e2e_device_keys_json",
keyvalues=keyvalues,
retcols=["device_id", "key_json"]
)
if not query_list:
return {}
return self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
)
def _get_e2e_device_keys_txn(self, txn, query_list):
query_clauses = []
query_params = []
for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?"
query_params.append(user_id)
if device_id:
query_clause += " AND k.device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
sql = (
"SELECT k.user_id, k.device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM e2e_device_keys_json k"
" LEFT JOIN devices d ON d.user_id = k.user_id"
" AND d.device_id = k.device_id"
" WHERE %s"
) % (
" OR ".join("(" + q + ")" for q in query_clauses)
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict)
for row in rows:
user_result[row["device_id"]] = row["key_json"]
result[row["user_id"]][row["device_id"]] = row
return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
@@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@twisted.internet.defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete(
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_device_keys_by_device"
)
yield self._simple_delete(
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device"
)
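Deleting a device is expected to clear its end-to-end keys as well; a hedged sketch of the glue a caller (such as the new DeviceHandler) would perform, with a hypothetical helper name:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def remove_device_and_keys(store, user_id, device_id):
        # Drop the device row, then the device keys and one-time keys it
        # uploaded. (Illustrative ordering; not a transcription of the
        # actual handler.)
        yield store.delete_device(user_id, device_id)
        yield store.delete_e2e_keys_by_device(user_id, device_id)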
@@ -16,6 +16,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import RoomStreamToken
from .stream import lower_bound
import logging
import ujson as json
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0]
topological_ordering = results[0][1]
token = RoomStreamToken(
topological_ordering, stream_ordering
)
sql = (
"SELECT sum(notif), sum(highlight)"
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
" AND (" " AND %s"
" topological_ordering > ?" ) % (lower_bound(token, self.database_engine, inclusive=False),)
" OR (topological_ordering = ? AND stream_ordering > ?)"
")" txn.execute(sql, (user_id, room_id))
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
))
row = txn.fetchone()
if row:
return {
@@ -117,21 +117,149 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering,
max_stream_ordering=None,
limit=20):
def get_unread_push_actions_for_user_in_range_for_http(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
# find rooms that have a read receipt in them and return the next
# push actions
def get_after_receipt(txn):
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
" WHERE"
" ep.room_id = rl.room_id"
" AND ("
" ep.topological_ordering > rl.topological_ordering"
" OR ("
" ep.topological_ordering = rl.topological_ordering"
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id NOT IN ("
" SELECT room_id FROM receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
" )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
notifs = [
{
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
"actions": json.loads(row[3]),
} for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
notifs.sort(key=lambda r: r['stream_ordering'])
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
# find rooms that have a read receipt in them and return the most recent
# push actions
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
"e.received_ts "
"FROM ("
" SELECT room_id, user_id, "
" max(topological_ordering) as topological_ordering, "
" max(stream_ordering) as stream_ordering "
" FROM events"
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
" GROUP BY room_id, user_id"
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
@@ -144,46 +272,53 @@ class EventPushActionsStore(SQLBaseStore):
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.stream_ordering > ?"
" AND ep.user_id = ?" " AND ep.user_id = ?"
" AND ep.user_id = rl.user_id" " AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
) )
args = [min_stream_ordering, user_id] args = [
if max_stream_ordering is not None: user_id, user_id,
sql += " AND ep.stream_ordering <= ?" min_stream_ordering, max_stream_ordering, limit,
args.append(max_stream_ordering) ]
sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_after_receipt
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE ep.room_id not in (" " WHERE"
" SELECT room_id FROM events NATURAL JOIN receipts_linearized" " ep.room_id NOT IN ("
" SELECT room_id FROM receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ?" " WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id" " GROUP BY room_id"
") AND ep.user_id = ? AND ep.stream_ordering > ?" " )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC"
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_no_receipt
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
defer.returnValue([
# Make a list of dicts from the two sets of results.
notifs = [
{
"event_id": row[0],
"room_id": row[1],
@@ -191,7 +326,16 @@ class EventPushActionsStore(SQLBaseStore):
"actions": json.loads(row[3]),
"received_ts": row[4],
} for row in after_read_receipt + no_read_receipt
])
]
# Now sort it so it's ordered correctly, since currently it will
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
# Now return the first `limit`
defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
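Both range functions share the same merge step: each of the two queries is already capped at `limit`, so the union can hold up to 2 * limit rows, which is re-sorted in Python and cut back down. The same pattern in isolation (for the email variant's ordering, with invented rows):

    def merge_limited(after_read_receipt, no_read_receipt, limit):
        notifs = after_read_receipt + no_read_receipt
        # Most recent first; rows with no received_ts sort last.
        notifs.sort(key=lambda r: -(r["received_ts"] or 0))
        return notifs[:limit]

    rows_a = [{"event_id": "$a", "received_ts": 300}]
    rows_b = [{"event_id": "$b", "received_ts": 500},
              {"event_id": "$c", "received_ts": None}]
    assert [r["event_id"] for r in merge_limited(rows_a, rows_b, 2)] == ["$b", "$a"]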
@@ -23,9 +23,11 @@ from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple
from collections import deque, namedtuple, OrderedDict
from functools import wraps
import synapse
import synapse.metrics
@@ -149,8 +151,29 @@ class _EventPeristenceQueue(object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
def _retry_on_integrity_error(func):
"""Wraps a database function so that it gets retried on IntegrityError,
with `delete_existing=True` passed in.
Args:
func: function that returns a Deferred and accepts a `delete_existing` arg
"""
@wraps(func)
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
res = yield func(self, *args, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
defer.returnValue(res)
return f
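The decorator retries the wrapped call exactly once, adding delete_existing=True on the second attempt; a self-contained sketch with stubbed-out store and error types (all names hypothetical):

    from twisted.internet import defer

    class FakeIntegrityError(Exception):
        pass

    class FakeStore(object):
        class database_engine(object):
            class module(object):
                IntegrityError = FakeIntegrityError

        def __init__(self):
            self.calls = []

        @_retry_on_integrity_error
        @defer.inlineCallbacks
        def persist(self, delete_existing=False):
            self.calls.append(delete_existing)
            if not delete_existing:
                raise FakeIntegrityError()
            yield defer.succeed(None)
            defer.returnValue("ok")

    # FakeStore().persist() fires with "ok" and records calls == [False, True].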
class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
@@ -158,6 +181,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
self.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
self._event_persist_queue = _EventPeristenceQueue()
@@ -223,8 +250,10 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@_retry_on_integrity_error
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False):
def _persist_events(self, events_and_contexts, backfilled=False,
delete_existing=False):
if not events_and_contexts:
return
@@ -267,12 +296,15 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
def _persist_event(self, event, context, current_state=None, backfilled=False):
def _persist_event(self, event, context, current_state=None, backfilled=False,
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
@@ -285,6 +317,7 @@ class EventsStore(SQLBaseStore):
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
@@ -317,7 +350,7 @@ class EventsStore(SQLBaseStore):
)
if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@@ -347,7 +380,8 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@@ -355,7 +389,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -388,10 +421,38 @@ class EventsStore(SQLBaseStore):
txn,
[(event, context)],
backfilled=backfilled,
delete_existing=delete_existing,
)
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError.
"""
# Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
if not event.internal_metadata.is_outlier():
if prev_event_context[0].internal_metadata.is_outlier():
# To ensure correct ordering we pop, as OrderedDict is
# ordered by first insertion.
new_events_and_contexts.pop(event.event_id, None)
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
events_and_contexts = new_events_and_contexts.values()
depth_updates = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
@@ -402,21 +463,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
if not event.internal_metadata.is_outlier():
if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@@ -426,30 +477,21 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
event_map = {}
to_remove = set() to_remove = set()
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
if event.event_id in event_map:
other = event_map[event.event_id]
if not other.internal_metadata.is_outlier():
to_remove.add(event)
continue
elif not event.internal_metadata.is_outlier():
to_remove.add(event)
continue
else:
to_remove.add(other)
event_map[event.event_id] = event
if context.rejected:
# If the event is rejected then we don't care if the event
# was an outlier or not.
if event.event_id in have_persisted:
# If we have already seen the event then ignore it.
to_remove.add(event)
continue
if event.event_id not in have_persisted:
continue
@@ -458,6 +500,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@@ -473,6 +521,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@@ -494,6 +544,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@@ -501,38 +553,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
self._store_mult_state_groups_txn(txn, events_and_contexts)
# From this point onwards the events are only events that we haven't
# seen before.
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
def event_dict(event):
return {
@@ -544,6 +570,43 @@ class EventsStore(SQLBaseStore):
]
}
if delete_existing:
# For paranoia reasons, we go and delete all the existing entries
# for these events so we can reinsert them.
# This gets around any problems with some tables already having
# entries.
logger.info("Deleting existing")
for table in (
"events",
"event_auth",
"event_json",
"event_content_hashes",
"event_destinations",
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
"event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
"event_to_state_groups",
"guest_access",
"history_visibility",
"local_invites",
"room_names",
"state_events",
"rejections",
"redactions",
"room_memberships",
"state_events"
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
[(ev.event_id,) for ev, _ in events_and_contexts]
)
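
The delete_existing path makes the insert idempotent: wiping any half-written rows first means a retried persist cannot trip over partial state. A hedged sketch of the pattern, with a generic transaction cursor and a trimmed, allowlisted table set standing in for the full list above:

    EVENT_TABLES = ("events", "event_json", "event_edges")

    def delete_existing_rows(txn, event_ids, tables=EVENT_TABLES):
        for table in tables:
            # Table names come from a fixed allowlist, so the string
            # interpolation here cannot inject SQL; the event IDs are
            # still passed as bound parameters.
            txn.executemany(
                "DELETE FROM %s WHERE event_id = ?" % (table,),
                [(event_id,) for event_id in event_ids],
            )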
self._simple_insert_many_txn(
txn,
table="event_json",

@@ -576,15 +639,51 @@ class EventsStore(SQLBaseStore):

"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
),
}
for event, _ in events_and_contexts
],
)
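
The contains_url flag is denormalised into the events table at insert time so the new 'contains_url' message filter can be evaluated in SQL. A minimal sketch of the same predicate, using a hypothetical event_content dict in place of event.content (Python 2 era, hence basestring):

    def compute_contains_url(event_content):
        # True only when a "url" key is present and holds a string,
        # mirroring the expression used to populate events.contains_url.
        return (
            "url" in event_content
            and isinstance(event_content["url"], basestring)
        )

    assert compute_contains_url({"url": "mxc://example.org/abc"})
    assert not compute_contains_url({"url": 123})
    assert not compute_contains_url({"body": "no url here"})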
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
self._store_rejections_txn(
txn, event.event_id, context.rejected
)
to_remove.add(event)
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
# From this point onwards the events are only ones that weren't rejected.
for event, context in events_and_contexts:
# Insert all the push actions into the event_push_actions table.
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
self._simple_insert_many_txn(
txn,

@@ -600,6 +699,49 @@ class EventsStore(SQLBaseStore):

],
)
# Insert into the state_groups, state_groups_state, and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, events_and_contexts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
# Insert into the room_names and event_search tables.
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
# Insert into the topics table and event_search table.
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
# Insert into the event_search table.
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
# Insert into the redactions table.
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
# Insert into the event_search table.
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
# Insert into the event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)

@@ -644,6 +786,7 @@ class EventsStore(SQLBaseStore):

],
)

# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)

if backfilled:

@@ -656,22 +799,11 @@ class EventsStore(SQLBaseStore):

# Outlier events shouldn't clobber the current state.
continue
if context.rejected:
# If the event failed its auth checks then it shouldn't
# clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
self._simple_upsert_txn(
txn,
"current_state_events",

@@ -1121,6 +1253,78 @@ class EventsStore(SQLBaseStore):

ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
update_rows = []
for row in rows:
try:
event_id = row[1]
event_json = json.loads(row[2])
sender = event_json["sender"]
content = event_json["content"]
contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
update_rows.append((sender, contains_url, event_id))
sql = (
"UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
)
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
clump = update_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
}
self._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
defer.returnValue(result)
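
_background_reindex_fields_sender works in batches, walking stream_ordering downwards and recording its progress so the update survives restarts. A rough sketch of that contract, with run_batch standing in for any batch function that returns the number of rows handled and 0 once finished:

    def run_background_update_to_completion(run_batch, batch_size=100):
        total = 0
        while True:
            handled = run_batch(batch_size)
            if not handled:
                # A batch that touches no rows marks the update as done.
                return total
            total += handled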
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]

@@ -1288,6 +1492,162 @@ class EventsStore(SQLBaseStore):

)

return self.runInteraction("get_all_new_events", get_all_new_events_txn)
def delete_old_state(self, room_id, topological_ordering):
return self.runInteraction(
"delete_old_state",
self._delete_old_state_txn, room_id, topological_ordering
)
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
"""Deletes old room state
"""
# Tables that should be pruned:
# event_auth
# event_backward_extremities
# event_content_hashes
# event_destinations
# event_edge_hashes
# event_edges
# event_forward_extremities
# event_json
# event_push_actions
# event_reference_hashes
# event_search
# event_signatures
# event_to_state_groups
# events
# rejections
# room_depth
# state_groups
# state_groups_state
# First ensure that we're not about to delete all the forward extremities
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
"ON e.event_id = f.event_id "
"AND e.room_id = f.room_id "
"WHERE f.room_id = ?",
(room_id,)
)
rows = txn.fetchall()
max_depth = max(row[1] for row in rows)

if max_depth <= topological_ordering:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any forward extremities)
raise SynapseError(
400, "topological_ordering is greater than forward extremities"
)
txn.execute(
"SELECT event_id, state_key FROM events"
" LEFT JOIN state_events USING (room_id, event_id)"
" WHERE room_id = ? AND topological_ordering < ?",
(room_id, topological_ordering,)
)
event_rows = txn.fetchall()
# We calculate the new entries for the backward extremities by finding
# all events that point to events that are to be purged
txn.execute(
"SELECT DISTINCT e.event_id FROM events as e"
" INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
" INNER JOIN events as e2 ON e2.event_id = ed.event_id"
" WHERE e.room_id = ? AND e.topological_ordering < ?"
" AND e2.topological_ordering >= ?",
(room_id, topological_ordering, topological_ordering)
)
new_backwards_extrems = txn.fetchall()
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
(room_id,)
)
# Update backward extremities
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[
(room_id, event_id) for event_id, in new_backwards_extrems
]
)
# Get all state groups that are only referenced by events that are
# to be deleted.
txn.execute(
"SELECT state_group FROM event_to_state_groups"
" INNER JOIN events USING (event_id)"
" WHERE state_group IN ("
" SELECT DISTINCT state_group FROM events"
" INNER JOIN event_to_state_groups USING (event_id)"
" WHERE room_id = ? AND topological_ordering < ?"
" )"
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
(room_id, topological_ordering, topological_ordering)
)
state_rows = txn.fetchall()
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
state_rows
)
txn.executemany(
"DELETE FROM state_groups WHERE id = ?",
state_rows
)
# Delete the event_to_state_groups mappings for all the events being
# purged
txn.executemany(
"DELETE FROM event_to_state_groups WHERE event_id = ?",
[(event_id,) for event_id, _ in event_rows]
)
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(topological_ordering, room_id,)
)
# Delete all remote non-state events
to_delete = [
(event_id,) for event_id, state_key in event_rows
if state_key is None and not self.hs.is_mine_id(event_id)
]
for table in (
"events",
"event_json",
"event_auth",
"event_content_hashes",
"event_destinations",
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
"event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
"rejections",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
to_delete
)
txn.executemany(
"DELETE FROM events WHERE event_id = ?",
to_delete
)
# Mark all state and own events as outliers
txn.executemany(
"UPDATE events SET outlier = ?"
" WHERE event_id = ?",
[
(True, event_id,) for event_id, state_key in event_rows
if state_key is not None or self.hs.is_mine_id(event_id)
]
)
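
Taken together, _delete_old_state_txn keeps state events and local events (downgrading them to outliers) while deleting remote non-state events older than the cutoff. A hedged sketch of how a purge-history handler might drive it, assuming a store exposing delete_old_state as above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def purge_history(store, room_id, topological_ordering):
        # Raises SynapseError(400) if the cutoff is at or beyond every
        # forward extremity, which would leave the room unusable.
        yield store.delete_old_state(room_id, topological_ordering)
        defer.returnValue(None)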
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
@@ -22,6 +22,10 @@ import OpenSSL

from signedjson.key import decode_verify_key_bytes
import hashlib
import logging

logger = logging.getLogger(__name__)

class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates

@@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):

)

@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})

def _get_server_verify_key(self, server_name, key_id):
verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
"key_id": key_id,
},
retcol="verify_key",
desc="_get_server_verify_key",
allow_none=True,
)
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
key_id, str(verify_key_bytes)
))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):

@@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):

Returns:
(list of VerifyKey): The verification keys.
"""
keys = yield self.get_all_server_verify_keys(server_name)
defer.returnValue({
k: keys[k]
for k in key_ids
if k in keys and keys[k]
})

keys = {}
for key_id in key_ids:
key = yield self._get_server_verify_key(server_name, key_id)
if key:
keys[key_id] = key
defer.returnValue(keys)

@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,

@@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):

desc="store_server_verify_key",
)

self.get_all_server_verify_keys.invalidate((server_name,))

def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):

"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
"last_access_ts": time_now_ms,
},
desc="store_cached_remote_media",
)
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
)
txn.executemany(sql, (
(time_ts, media_origin, media_id)
for media_origin, media_id in origin_id_tuples
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
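
update_cached_last_access_time batches the timestamp updates with executemany, so serving a burst of cached media costs a single interaction. A hedged usage sketch; the store and clock objects here are assumptions:

    def touch_remote_media(store, clock, served_media):
        # served_media is an iterable of (media_origin, media_id) pairs
        # that were just read from the cache.
        return store.update_cached_last_access_time(
            served_media, clock.time_msec(),
        )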
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",

@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):

},
desc="store_remote_media_thumbnail",
)
def get_remote_media_before(self, before_ts):
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
return self._execute(
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
self._simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={
"media_origin": media_origin, "media_id": media_id
},
)
self._simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={
"media_origin": media_origin, "media_id": media_id
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
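
get_remote_media_before and delete_remote_media are the storage half of the purge_media_cache admin API. A minimal sketch of a purge pass, ignoring deferreds and assuming a hypothetical remove_files helper for the on-disk filesystem_id:

    def purge_remote_media(store, before_ts, remove_files):
        old_media = store.get_remote_media_before(before_ts)
        for media in old_media:
            remove_files(media["filesystem_id"])
            store.delete_remote_media(
                media["media_origin"], media["media_id"],
            )
        return len(old_media)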
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 32
SCHEMA_VERSION = 33

dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -18,25 +18,40 @@ import re

from twisted.internet import defer

from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks

class RegistrationStore(SQLBaseStore):
class RegistrationStore(background_updates.BackgroundUpdateStore):

def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)

self.clock = hs.get_clock()
self.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token):
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.

Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):

{
"id": next_id,
"user_id": user_id,
"token": token
"token": token,
"device_id": device_id,
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token):
def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.

Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the refresh
token
Raises:
StoreError if there was a problem adding this.
"""
@@ -69,20 +87,23 @@ class RegistrationStore(SQLBaseStore):

{
"id": next_id,
"user_id": user_id,
"token": token
"token": token,
"device_id": device_id,
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
def register(self, user_id, token, password_hash,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None):
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.

Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -104,6 +125,7 @@ class RegistrationStore(SQLBaseStore):

make_guest,
appservice_id,
create_profile_with_localpart,
admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))

@@ -118,6 +140,7 @@ class RegistrationStore(SQLBaseStore):

make_guest,
appservice_id,
create_profile_with_localpart,
admin,
):

now = int(self.clock.time())
@@ -125,29 +148,48 @@ class RegistrationStore(SQLBaseStore):

try:
if was_guest:
txn.execute("UPDATE users SET"
" password_hash = ?,"
" upgrade_ts = ?,"
" is_guest = ?"
" WHERE name = ?",
[password_hash, now, 1 if make_guest else 0, user_id])

# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
self._simple_select_one_txn(
txn,
"users",
keyvalues={
"name": user_id,
"is_guest": 1,
},
retcols=("name",),
allow_none=False,
)

self._simple_update_one_txn(
txn,
"users",
keyvalues={
"name": user_id,
"is_guest": 1,
},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
}
)
else:
txn.execute("INSERT INTO users "
"("
" name,"
" password_hash,"
" creation_ts,"
" is_guest,"
" appservice_id"
") "
"VALUES (?,?,?,?,?)",
[
user_id,
password_hash,
now,
1 if make_guest else 0,
appservice_id,
])

self._simple_insert_txn(
txn,
"users",
values={
"name": user_id,
"password_hash": password_hash,
"creation_ts": now,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
}
)
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -209,16 +251,37 @@ class RegistrationStore(SQLBaseStore):

self.get_user_by_id.invalidate((user_id,))

@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None,
delete_refresh_tokens=False):
"""
Invalidate access/refresh tokens belonging to a user

Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
"""
def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
if except_token_ids:
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)

if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]),
",".join(["?" for _ in except_tokens]),
)
clauses += except_token_ids
clauses += except_tokens

txn.execute(sql, clauses)
@@ -227,16 +290,33 @@ class RegistrationStore(SQLBaseStore):

n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM access_tokens WHERE token in (%s)" % (
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
yield self.runInteraction("user_delete_access_tokens", f)

# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
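
With the new keyword arguments, deleting a device can invalidate just that device's tokens, refresh tokens included. A hedged usage sketch, assuming a store exposing user_delete_access_tokens as above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def delete_device_tokens(store, user_id, device_id):
        # Scopes the deletion to one device; refresh tokens go first so
        # no new access tokens can be minted mid-delete.
        yield store.user_delete_access_tokens(
            user_id,
            device_id=device_id,
            delete_refresh_tokens=True,
        )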
def delete_access_token(self, access_token):
def f(txn):

@@ -259,9 +339,8 @@ class RegistrationStore(SQLBaseStore):

Args:
token (str): The access token of a user.
Returns:
dict: Including the name (user_id) and the ID of their access token.
Raises:
StoreError if no user was found.
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",

@@ -270,18 +349,18 @@ class RegistrationStore(SQLBaseStore):

)
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new access token and refresh token.
"""Exchange a refresh token for a new one.

Doing so invalidates the old refresh token - refresh tokens are single
use.

Args:
token (str): The refresh token of a user.
refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, refresh_token)
tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
@@ -293,12 +372,13 @@ class RegistrationStore(SQLBaseStore):

)

def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]

# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.

@@ -307,7 +387,7 @@ class RegistrationStore(SQLBaseStore):

sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))

return user_id, new_token
return user_id, new_token, device_id
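
The token_generator argument only has to honour the contract in the docstring: given a user ID, return a token that is never repeated. The real implementation is macaroon-based; this counter-backed stand-in is purely illustrative:

    import itertools

    _counter = itertools.count()

    def token_generator(user_id):
        # Each call returns a distinct token for the same user.
        return "%s_refresh_%d" % (user_id, next(_counter))

    # _exchange_refresh_token(txn, old_token, token_generator) then
    # returns (user_id, new_refresh_token, device_id) for the matched row.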
@defer.inlineCallbacks
def is_server_admin(self, user):

@@ -335,7 +415,8 @@ class RegistrationStore(SQLBaseStore):

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"

@@ -384,6 +465,15 @@ class RegistrationStore(SQLBaseStore):

defer.returnValue(ret['user_id'])
defer.returnValue(None)
def user_delete_threepids(self, user_id):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
},
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -18,7 +18,6 @@ from twisted.internet import defer

from synapse.api.errors import StoreError
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine

import collections

@@ -192,49 +191,6 @@ class RoomStore(SQLBaseStore):

# This should be unreachable.
raise Exception("Unrecognized database engine")
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def get_room_name(txn):
sql = (
"SELECT name FROM room_names"
" INNER JOIN current_state_events USING (room_id, event_id)"
" WHERE room_id = ?"
" LIMIT 1"
)
txn.execute(sql, (room_id,))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
return None
def get_room_aliases(txn):
sql = (
"SELECT content FROM current_state_events"
" INNER JOIN events USING (room_id, event_id)"
" WHERE room_id = ?"
)
txn.execute(sql, (room_id,))
return [row[0] for row in txn.fetchall()]
name = yield self.runInteraction("get_room_name", get_room_name)
alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
aliases = []
for c in alias_contents:
try:
content = json.loads(c)
except:
continue
aliases.extend(content.get('aliases', []))
defer.returnValue((name, aliases))
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('access_tokens_device_index', '{}');
@@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE devices (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
display_name TEXT,
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
);
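
The device_uniqueness constraint makes (user_id, device_id) the natural key of the table. A hedged sketch of its behaviour against an in-memory sqlite3 database with the same schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE devices ("
        " user_id TEXT NOT NULL,"
        " device_id TEXT NOT NULL,"
        " display_name TEXT,"
        " CONSTRAINT device_uniqueness UNIQUE (user_id, device_id))"
    )
    conn.execute(
        "INSERT INTO devices VALUES (?, ?, ?)",
        ("@alice:example.org", "DEVICE1", "Alice's phone"),
    )
    try:
        # A second row for the same (user_id, device_id) violates the
        # UNIQUE constraint.
        conn.execute(
            "INSERT INTO devices VALUES (?, ?, ?)",
            ("@alice:example.org", "DEVICE1", None),
        )
    except sqlite3.IntegrityError:
        pass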
@@ -0,0 +1,19 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- make sure that we have a device record for each set of E2E keys, so that the
-- user can delete them if they like.
INSERT INTO devices
SELECT user_id, device_id, NULL FROM e2e_device_keys_json;