mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-28 23:20:09 +03:00
Merge branch 'release-v0.17.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
d330d45e2d
136 changed files with 5946 additions and 1657 deletions
122
CHANGES.rst
122
CHANGES.rst
|
@ -1,3 +1,125 @@
|
||||||
|
Changes in synapse v0.17.0 (2016-08-08)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
This release contains significant security bug fixes regarding authenticating
|
||||||
|
events received over federation. PLEASE UPGRADE.
|
||||||
|
|
||||||
|
This release changes the LDAP configuration format in a backwards incompatible
|
||||||
|
way, see PR #843 for details.
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Add federation /version API (PR #990)
|
||||||
|
* Make psutil dependency optional (PR #992)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix URL preview API to exclude HTML comments in description (PR #988)
|
||||||
|
* Fix error handling of remote joins (PR #991)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.0-rc4 (2016-08-05)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Change the way we summarize URLs when previewing (PR #973)
|
||||||
|
* Add new ``/state_ids/`` federation API (PR #979)
|
||||||
|
* Speed up processing of ``/state/`` response (PR #986)
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix event persistence when event has already been partially persisted
|
||||||
|
(PR #975, #983, #985)
|
||||||
|
* Fix port script to also copy across backfilled events (PR #982)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.0-rc3 (2016-08-02)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
|
||||||
|
* Add some basic admin API docs (PR #963)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Send the correct host header when fetching keys (PR #941)
|
||||||
|
* Fix joining a room that has missing auth events (PR #964)
|
||||||
|
* Fix various push bugs (PR #966, #970)
|
||||||
|
* Fix adding emails on registration (PR #968)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.0-rc2 (2016-08-02)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
(This release did not include the changes advertised and was identical to RC1)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.0-rc1 (2016-07-28)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
This release changes the LDAP configuration format in a backwards incompatible
|
||||||
|
way, see PR #843 for details.
|
||||||
|
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add purge_media_cache admin API (PR #902)
|
||||||
|
* Add deactivate account admin API (PR #903)
|
||||||
|
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
|
||||||
|
* Add an admin option to shared secret registration (breaks backwards compat)
|
||||||
|
(PR #909)
|
||||||
|
* Add purge local room history API (PR #911, #923, #924)
|
||||||
|
* Add requestToken endpoints (PR #915)
|
||||||
|
* Add an /account/deactivate endpoint (PR #921)
|
||||||
|
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
|
||||||
|
* Add device_id support to /login (PR #929)
|
||||||
|
* Add device_id support to /v2/register flow. (PR #937, #942)
|
||||||
|
* Add GET /devices endpoint (PR #939, #944)
|
||||||
|
* Add GET /device/{deviceId} (PR #943)
|
||||||
|
* Add update and delete APIs for devices (PR #949)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
|
||||||
|
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
|
||||||
|
* Remove the legacy v0 content upload API. (PR #888)
|
||||||
|
* Use similar naming we use in email notifs for push (PR #894)
|
||||||
|
* Optionally include password hash in createUser endpoint (PR #905 by
|
||||||
|
KentShikama)
|
||||||
|
* Use a query that postgresql optimises better for get_events_around (PR #906)
|
||||||
|
* Fall back to 'username' if 'user' is not given for appservice registration.
|
||||||
|
(PR #927 by Half-Shot)
|
||||||
|
* Add metrics for psutil derived memory usage (PR #936)
|
||||||
|
* Record device_id in client_ips (PR #938)
|
||||||
|
* Send the correct host header when fetching keys (PR #941)
|
||||||
|
* Log the hostname the reCAPTCHA was completed on (PR #946)
|
||||||
|
* Make the device id on e2e key upload optional (PR #956)
|
||||||
|
* Add r0.2.0 to the "supported versions" list (PR #960)
|
||||||
|
* Don't include name of room for invites in push (PR #961)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix substitution failure in mail template (PR #887)
|
||||||
|
* Put most recent 20 messages in email notif (PR #892)
|
||||||
|
* Ensure that the guest user is in the database when upgrading accounts
|
||||||
|
(PR #914)
|
||||||
|
* Fix various edge cases in auth handling (PR #919)
|
||||||
|
* Fix 500 ISE when sending alias event without a state_key (PR #925)
|
||||||
|
* Fix bug where we stored rejections in the state_group, persist all
|
||||||
|
rejections (PR #948)
|
||||||
|
* Fix lack of check of if the user is banned when handling 3pid invites
|
||||||
|
(PR #952)
|
||||||
|
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||||
==========================================
|
==========================================
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ recursive-include docs *
|
||||||
recursive-include res *
|
recursive-include res *
|
||||||
recursive-include scripts *
|
recursive-include scripts *
|
||||||
recursive-include scripts-dev *
|
recursive-include scripts-dev *
|
||||||
|
recursive-include synapse *.pyi
|
||||||
recursive-include tests *.py
|
recursive-include tests *.py
|
||||||
|
|
||||||
recursive-include synapse/static *.css
|
recursive-include synapse/static *.css
|
||||||
|
@ -23,5 +24,7 @@ recursive-include synapse/static *.js
|
||||||
|
|
||||||
exclude jenkins.sh
|
exclude jenkins.sh
|
||||||
exclude jenkins*.sh
|
exclude jenkins*.sh
|
||||||
|
exclude jenkins*
|
||||||
|
recursive-exclude jenkins *.sh
|
||||||
|
|
||||||
prune demo/etc
|
prune demo/etc
|
||||||
|
|
|
@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
|
||||||
IDs:
|
IDs:
|
||||||
|
|
||||||
1) Use the machine's own hostname as available on public DNS in the form of
|
1) Use the machine's own hostname as available on public DNS in the form of
|
||||||
its A or AAAA records. This is easier to set up initially, perhaps for
|
its A records. This is easier to set up initially, perhaps for
|
||||||
testing, but lacks the flexibility of SRV.
|
testing, but lacks the flexibility of SRV.
|
||||||
|
|
||||||
2) Set up a SRV record for your domain name. This requires you create a SRV
|
2) Set up a SRV record for your domain name. This requires you create a SRV
|
||||||
|
|
|
@ -27,7 +27,7 @@ running:
|
||||||
# Pull the latest version of the master branch.
|
# Pull the latest version of the master branch.
|
||||||
git pull
|
git pull
|
||||||
# Update the versions of synapse's python dependencies.
|
# Update the versions of synapse's python dependencies.
|
||||||
python synapse/python_dependencies.py | xargs -n1 pip install
|
python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
|
||||||
|
|
||||||
|
|
||||||
Upgrading to v0.15.0
|
Upgrading to v0.15.0
|
||||||
|
|
12
docs/admin_api/README.rst
Normal file
12
docs/admin_api/README.rst
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
Admin APIs
|
||||||
|
==========
|
||||||
|
|
||||||
|
This directory includes documentation for the various synapse specific admin
|
||||||
|
APIs available.
|
||||||
|
|
||||||
|
Only users that are server admins can use these APIs. A user can be marked as a
|
||||||
|
server admin by updating the database directly, e.g.:
|
||||||
|
|
||||||
|
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
|
||||||
|
|
||||||
|
Restarting may be required for the changes to register.
|
15
docs/admin_api/purge_history_api.rst
Normal file
15
docs/admin_api/purge_history_api.rst
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
Purge History API
|
||||||
|
=================
|
||||||
|
|
||||||
|
The purge history API allows server admins to purge historic events from their
|
||||||
|
database, reclaiming disk space.
|
||||||
|
|
||||||
|
Depending on the amount of history being purged a call to the API may take
|
||||||
|
several minutes or longer. During this period users will not be able to
|
||||||
|
paginate further back in the room from the point being purged from.
|
||||||
|
|
||||||
|
The API is simply:
|
||||||
|
|
||||||
|
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||||
|
|
||||||
|
including an ``access_token`` of a server admin.
|
19
docs/admin_api/purge_remote_media.rst
Normal file
19
docs/admin_api/purge_remote_media.rst
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
Purge Remote Media API
|
||||||
|
======================
|
||||||
|
|
||||||
|
The purge remote media API allows server admins to purge old cached remote
|
||||||
|
media.
|
||||||
|
|
||||||
|
The API is::
|
||||||
|
|
||||||
|
POST /_matrix/client/r0/admin/purge_media_cache
|
||||||
|
|
||||||
|
{
|
||||||
|
"before_ts": <unix_timestamp_in_ms>
|
||||||
|
}
|
||||||
|
|
||||||
|
Which will remove all cached media that was last accessed before
|
||||||
|
``<unix_timestamp_in_ms>``.
|
||||||
|
|
||||||
|
If the user re-requests purged remote media, synapse will re-request the media
|
||||||
|
from the originating server.
|
|
@ -43,7 +43,10 @@ Basically, PEP8
|
||||||
together, or want to deliberately extend or preserve vertical/horizontal
|
together, or want to deliberately extend or preserve vertical/horizontal
|
||||||
space)
|
space)
|
||||||
|
|
||||||
Comments should follow the google code style. This is so that we can generate
|
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||||
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
This is so that we can generate documentation with
|
||||||
|
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||||
|
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||||
|
in the sphinx documentation.
|
||||||
|
|
||||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||||
|
|
|
@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
|
||||||
server through the use of a secret shared between the Home Server and the
|
server through the use of a secret shared between the Home Server and the
|
||||||
TURN server.
|
TURN server.
|
||||||
|
|
||||||
This document described how to install coturn
|
This document describes how to install coturn
|
||||||
(https://code.google.com/p/coturn/) which also supports the TURN REST API,
|
(https://github.com/coturn/coturn) which also supports the TURN REST API,
|
||||||
and integrate it with synapse.
|
and integrate it with synapse.
|
||||||
|
|
||||||
coturn Setup
|
coturn Setup
|
||||||
============
|
============
|
||||||
|
|
||||||
|
You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
|
||||||
|
|
||||||
1. Check out coturn::
|
1. Check out coturn::
|
||||||
svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
|
|
||||||
|
git clone https://github.com/coturn/coturn.git coturn
|
||||||
cd coturn
|
cd coturn
|
||||||
|
|
||||||
2. Configure it::
|
2. Configure it::
|
||||||
|
|
||||||
./configure
|
./configure
|
||||||
|
|
||||||
You may need to install libevent2: if so, you should do so
|
You may need to install ``libevent2``: if so, you should do so
|
||||||
in the way recommended by your operating system.
|
in the way recommended by your operating system.
|
||||||
You can ignore warnings about lack of database support: a
|
You can ignore warnings about lack of database support: a
|
||||||
database is unnecessary for this purpose.
|
database is unnecessary for this purpose.
|
||||||
|
|
||||||
3. Build and install it::
|
3. Build and install it::
|
||||||
|
|
||||||
make
|
make
|
||||||
make install
|
make install
|
||||||
|
|
||||||
4. Make a config file in /etc/turnserver.conf. You can customise
|
4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
|
||||||
a config file from turnserver.conf.default. The relevant
|
|
||||||
lines, with example values, are::
|
lines, with example values, are::
|
||||||
|
|
||||||
lt-cred-mech
|
lt-cred-mech
|
||||||
|
@ -41,7 +45,7 @@ coturn Setup
|
||||||
static-auth-secret=[your secret key here]
|
static-auth-secret=[your secret key here]
|
||||||
realm=turn.myserver.org
|
realm=turn.myserver.org
|
||||||
|
|
||||||
See turnserver.conf.default for explanations of the options.
|
See turnserver.conf for explanations of the options.
|
||||||
One way to generate the static-auth-secret is with pwgen::
|
One way to generate the static-auth-secret is with pwgen::
|
||||||
|
|
||||||
pwgen -s 64 1
|
pwgen -s 64 1
|
||||||
|
@ -54,6 +58,7 @@ coturn Setup
|
||||||
import your private key and certificate.
|
import your private key and certificate.
|
||||||
|
|
||||||
7. Start the turn server::
|
7. Start the turn server::
|
||||||
|
|
||||||
bin/turnserver -o
|
bin/turnserver -o
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,83 +4,19 @@ set -eux
|
||||||
|
|
||||||
: ${WORKSPACE:="$(pwd)"}
|
: ${WORKSPACE:="$(pwd)"}
|
||||||
|
|
||||||
|
export WORKSPACE
|
||||||
export PYTHONDONTWRITEBYTECODE=yep
|
export PYTHONDONTWRITEBYTECODE=yep
|
||||||
export SYNAPSE_CACHE_FACTOR=1
|
export SYNAPSE_CACHE_FACTOR=1
|
||||||
|
|
||||||
# Output test results as junit xml
|
./jenkins/prepare_synapse.sh
|
||||||
export TRIAL_FLAGS="--reporter=subunit"
|
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
|
||||||
# Write coverage reports to a separate file for each process
|
./dendron/jenkins/build_dendron.sh
|
||||||
export COVERAGE_OPTS="-p"
|
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
|
||||||
|
|
||||||
# Output flake8 violations to violations.flake8.log
|
./sytest/jenkins/install_and_run.sh \
|
||||||
# Don't exit with non-0 status code on Jenkins,
|
|
||||||
# so that the build steps continue and a later step can decided whether to
|
|
||||||
# UNSTABLE or FAILURE this build.
|
|
||||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
|
||||||
|
|
||||||
rm .coverage* || echo "No coverage files to remove"
|
|
||||||
|
|
||||||
tox --notest -e py27
|
|
||||||
|
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
|
||||||
$TOX_BIN/pip install psycopg2
|
|
||||||
$TOX_BIN/pip install lxml
|
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
|
||||||
|
|
||||||
if [[ ! -e .dendron-base ]]; then
|
|
||||||
git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
|
|
||||||
else
|
|
||||||
(cd .dendron-base; git fetch -p)
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf dendron
|
|
||||||
git clone .dendron-base dendron --shared
|
|
||||||
cd dendron
|
|
||||||
|
|
||||||
: ${GOPATH:=${WORKSPACE}/.gopath}
|
|
||||||
if [[ "${GOPATH}" != *:* ]]; then
|
|
||||||
mkdir -p "${GOPATH}"
|
|
||||||
export PATH="${GOPATH}/bin:${PATH}"
|
|
||||||
fi
|
|
||||||
export GOPATH
|
|
||||||
|
|
||||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
|
||||||
|
|
||||||
go get github.com/constabulary/gb/...
|
|
||||||
gb generate
|
|
||||||
gb build
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ! -e .sytest-base ]]; then
|
|
||||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
|
||||||
else
|
|
||||||
(cd .sytest-base; git fetch -p)
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf sytest
|
|
||||||
git clone .sytest-base sytest --shared
|
|
||||||
cd sytest
|
|
||||||
|
|
||||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
|
||||||
|
|
||||||
: ${PORT_BASE:=8000}
|
|
||||||
|
|
||||||
./jenkins/prep_sytest_for_postgres.sh
|
|
||||||
|
|
||||||
mkdir -p var
|
|
||||||
|
|
||||||
echo >&2 "Running sytest with PostgreSQL";
|
|
||||||
./jenkins/install_and_run.sh --python $TOX_BIN/python \
|
|
||||||
--synapse-directory $WORKSPACE \
|
--synapse-directory $WORKSPACE \
|
||||||
--dendron $WORKSPACE/dendron/bin/dendron \
|
--dendron $WORKSPACE/dendron/bin/dendron \
|
||||||
--pusher \
|
--pusher \
|
||||||
--synchrotron \
|
--synchrotron \
|
||||||
--port-base $PORT_BASE
|
--federation-reader \
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
|
@ -4,60 +4,14 @@ set -eux
|
||||||
|
|
||||||
: ${WORKSPACE:="$(pwd)"}
|
: ${WORKSPACE:="$(pwd)"}
|
||||||
|
|
||||||
|
export WORKSPACE
|
||||||
export PYTHONDONTWRITEBYTECODE=yep
|
export PYTHONDONTWRITEBYTECODE=yep
|
||||||
export SYNAPSE_CACHE_FACTOR=1
|
export SYNAPSE_CACHE_FACTOR=1
|
||||||
|
|
||||||
# Output test results as junit xml
|
./jenkins/prepare_synapse.sh
|
||||||
export TRIAL_FLAGS="--reporter=subunit"
|
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
|
||||||
# Write coverage reports to a separate file for each process
|
|
||||||
export COVERAGE_OPTS="-p"
|
|
||||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
|
||||||
|
|
||||||
# Output flake8 violations to violations.flake8.log
|
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||||
# Don't exit with non-0 status code on Jenkins,
|
|
||||||
# so that the build steps continue and a later step can decided whether to
|
|
||||||
# UNSTABLE or FAILURE this build.
|
|
||||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
|
||||||
|
|
||||||
rm .coverage* || echo "No coverage files to remove"
|
./sytest/jenkins/install_and_run.sh \
|
||||||
|
|
||||||
tox --notest -e py27
|
|
||||||
|
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
|
||||||
$TOX_BIN/pip install psycopg2
|
|
||||||
$TOX_BIN/pip install lxml
|
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
|
||||||
|
|
||||||
if [[ ! -e .sytest-base ]]; then
|
|
||||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
|
||||||
else
|
|
||||||
(cd .sytest-base; git fetch -p)
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf sytest
|
|
||||||
git clone .sytest-base sytest --shared
|
|
||||||
cd sytest
|
|
||||||
|
|
||||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
|
||||||
|
|
||||||
: ${PORT_BASE:=8000}
|
|
||||||
|
|
||||||
./jenkins/prep_sytest_for_postgres.sh
|
|
||||||
|
|
||||||
echo >&2 "Running sytest with PostgreSQL";
|
|
||||||
./jenkins/install_and_run.sh --coverage \
|
|
||||||
--python $TOX_BIN/python \
|
|
||||||
--synapse-directory $WORKSPACE \
|
--synapse-directory $WORKSPACE \
|
||||||
--port-base $PORT_BASE
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
cp sytest/.coverage.* .
|
|
||||||
|
|
||||||
# Combine the coverage reports
|
|
||||||
echo "Combining:" .coverage.*
|
|
||||||
$TOX_BIN/python -m coverage combine
|
|
||||||
# Output coverage to coverage.xml
|
|
||||||
$TOX_BIN/coverage xml -o coverage.xml
|
|
||||||
|
|
|
@ -4,54 +4,12 @@ set -eux
|
||||||
|
|
||||||
: ${WORKSPACE:="$(pwd)"}
|
: ${WORKSPACE:="$(pwd)"}
|
||||||
|
|
||||||
|
export WORKSPACE
|
||||||
export PYTHONDONTWRITEBYTECODE=yep
|
export PYTHONDONTWRITEBYTECODE=yep
|
||||||
export SYNAPSE_CACHE_FACTOR=1
|
export SYNAPSE_CACHE_FACTOR=1
|
||||||
|
|
||||||
# Output test results as junit xml
|
./jenkins/prepare_synapse.sh
|
||||||
export TRIAL_FLAGS="--reporter=subunit"
|
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
|
||||||
# Write coverage reports to a separate file for each process
|
|
||||||
export COVERAGE_OPTS="-p"
|
|
||||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
|
||||||
|
|
||||||
# Output flake8 violations to violations.flake8.log
|
./sytest/jenkins/install_and_run.sh \
|
||||||
# Don't exit with non-0 status code on Jenkins,
|
|
||||||
# so that the build steps continue and a later step can decided whether to
|
|
||||||
# UNSTABLE or FAILURE this build.
|
|
||||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
|
||||||
|
|
||||||
rm .coverage* || echo "No coverage files to remove"
|
|
||||||
|
|
||||||
tox --notest -e py27
|
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
|
||||||
$TOX_BIN/pip install lxml
|
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
|
||||||
|
|
||||||
if [[ ! -e .sytest-base ]]; then
|
|
||||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
|
||||||
else
|
|
||||||
(cd .sytest-base; git fetch -p)
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf sytest
|
|
||||||
git clone .sytest-base sytest --shared
|
|
||||||
cd sytest
|
|
||||||
|
|
||||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
|
||||||
|
|
||||||
: ${PORT_BASE:=8500}
|
|
||||||
./jenkins/install_and_run.sh --coverage \
|
|
||||||
--python $TOX_BIN/python \
|
|
||||||
--synapse-directory $WORKSPACE \
|
--synapse-directory $WORKSPACE \
|
||||||
--port-base $PORT_BASE
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
cp sytest/.coverage.* .
|
|
||||||
|
|
||||||
# Combine the coverage reports
|
|
||||||
echo "Combining:" .coverage.*
|
|
||||||
$TOX_BIN/python -m coverage combine
|
|
||||||
# Output coverage to coverage.xml
|
|
||||||
$TOX_BIN/coverage xml -o coverage.xml
|
|
||||||
|
|
|
@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
|
||||||
|
|
||||||
rm .coverage* || echo "No coverage files to remove"
|
rm .coverage* || echo "No coverage files to remove"
|
||||||
|
|
||||||
|
tox --notest -e py27
|
||||||
|
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||||
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
|
|
||||||
tox -e py27
|
tox -e py27
|
||||||
|
|
44
jenkins/clone.sh
Executable file
44
jenkins/clone.sh
Executable file
|
@ -0,0 +1,44 @@
|
||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
# This clones a project from github into a named subdirectory
|
||||||
|
# If the project has a branch with the same name as this branch
|
||||||
|
# then it will checkout that branch after cloning.
|
||||||
|
# Otherwise it will checkout "origin/develop."
|
||||||
|
# The first argument is the name of the directory to checkout
|
||||||
|
# the branch into.
|
||||||
|
# The second argument is the URL of the remote repository to checkout.
|
||||||
|
# Usually something like https://github.com/matrix-org/sytest.git
|
||||||
|
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
NAME=$1
|
||||||
|
PROJECT=$2
|
||||||
|
BASE=".$NAME-base"
|
||||||
|
|
||||||
|
# Update our mirror.
|
||||||
|
if [ ! -d ".$NAME-base" ]; then
|
||||||
|
# Create a local mirror of the source repository.
|
||||||
|
# This saves us from having to download the entire repository
|
||||||
|
# when this script is next run.
|
||||||
|
git clone "$PROJECT" "$BASE" --mirror
|
||||||
|
else
|
||||||
|
# Fetch any updates from the source repository.
|
||||||
|
(cd "$BASE"; git fetch -p)
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Remove the existing repository so that we have a clean copy
|
||||||
|
rm -rf "$NAME"
|
||||||
|
# Cloning with --shared means that we will share portions of the
|
||||||
|
# .git directory with our local mirror.
|
||||||
|
git clone "$BASE" "$NAME" --shared
|
||||||
|
|
||||||
|
# Jenkins may have supplied us with the name of the branch in the
|
||||||
|
# environment. Otherwise we will have to guess based on the current
|
||||||
|
# commit.
|
||||||
|
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||||
|
cd "$NAME"
|
||||||
|
# check out the relevant branch
|
||||||
|
git checkout "${GIT_BRANCH}" || (
|
||||||
|
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
|
||||||
|
git checkout "origin/develop"
|
||||||
|
)
|
19
jenkins/prepare_synapse.sh
Executable file
19
jenkins/prepare_synapse.sh
Executable file
|
@ -0,0 +1,19 @@
|
||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
cd "`dirname $0`/.."
|
||||||
|
|
||||||
|
TOX_DIR=$WORKSPACE/.tox
|
||||||
|
|
||||||
|
mkdir -p $TOX_DIR
|
||||||
|
|
||||||
|
if ! [ $TOX_DIR -ef .tox ]; then
|
||||||
|
ln -s "$TOX_DIR" .tox
|
||||||
|
fi
|
||||||
|
|
||||||
|
# set up the virtualenv
|
||||||
|
tox -e py27 --notest -v
|
||||||
|
|
||||||
|
TOX_BIN=$TOX_DIR/py27/bin
|
||||||
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
|
$TOX_BIN/pip install lxml
|
||||||
|
$TOX_BIN/pip install psycopg2
|
|
@ -36,7 +36,7 @@
|
||||||
<div class="debug">
|
<div class="debug">
|
||||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
|
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||||
{% if reason.last_sent_ts %}
|
{% if reason.last_sent_ts %}
|
||||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||||
|
|
|
@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
|
||||||
authorization_headers = []
|
authorization_headers = []
|
||||||
|
|
||||||
for key, sig in signed_json["signatures"][origin_name].items():
|
for key, sig in signed_json["signatures"][origin_name].items():
|
||||||
authorization_headers.append(bytes(
|
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
|
||||||
origin_name, key, sig,
|
origin_name, key, sig,
|
||||||
)
|
)
|
||||||
))
|
authorization_headers.append(bytes(header))
|
||||||
|
sys.stderr.write(header)
|
||||||
|
sys.stderr.write("\n")
|
||||||
|
|
||||||
result = requests.get(
|
result = requests.get(
|
||||||
lookup(destination, path),
|
lookup(destination, path),
|
||||||
headers={"Authorization": authorization_headers[0]},
|
headers={"Authorization": authorization_headers[0]},
|
||||||
verify=False,
|
verify=False,
|
||||||
)
|
)
|
||||||
|
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||||
return result.json()
|
return result.json()
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,6 +143,7 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
json.dump(result, sys.stdout)
|
json.dump(result, sys.stdout)
|
||||||
|
print ""
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import getpass
|
import getpass
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
bcrypt_rounds=12
|
bcrypt_rounds=12
|
||||||
|
password_pepper = ""
|
||||||
|
|
||||||
def prompt_for_pass():
|
def prompt_for_pass():
|
||||||
password = getpass.getpass("Password: ")
|
password = getpass.getpass("Password: ")
|
||||||
|
@ -28,12 +34,22 @@ if __name__ == "__main__":
|
||||||
default=None,
|
default=None,
|
||||||
help="New password for user. Will prompt if omitted.",
|
help="New password for user. Will prompt if omitted.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-c", "--config",
|
||||||
|
type=argparse.FileType('r'),
|
||||||
|
help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if "config" in args and args.config:
|
||||||
|
config = yaml.safe_load(args.config)
|
||||||
|
bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
|
||||||
|
password_config = config.get("password_config", {})
|
||||||
|
password_pepper = password_config.get("pepper", password_pepper)
|
||||||
password = args.password
|
password = args.password
|
||||||
|
|
||||||
if not password:
|
if not password:
|
||||||
password = prompt_for_pass()
|
password = prompt_for_pass()
|
||||||
|
|
||||||
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
|
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||||
|
|
||||||
|
|
|
@ -25,18 +25,26 @@ import urllib2
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def request_registration(user, password, server_location, shared_secret):
|
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||||
mac = hmac.new(
|
mac = hmac.new(
|
||||||
key=shared_secret,
|
key=shared_secret,
|
||||||
msg=user,
|
|
||||||
digestmod=hashlib.sha1,
|
digestmod=hashlib.sha1,
|
||||||
).hexdigest()
|
)
|
||||||
|
|
||||||
|
mac.update(user)
|
||||||
|
mac.update("\x00")
|
||||||
|
mac.update(password)
|
||||||
|
mac.update("\x00")
|
||||||
|
mac.update("admin" if admin else "notadmin")
|
||||||
|
|
||||||
|
mac = mac.hexdigest()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"password": password,
|
"password": password,
|
||||||
"mac": mac,
|
"mac": mac,
|
||||||
"type": "org.matrix.login.shared_secret",
|
"type": "org.matrix.login.shared_secret",
|
||||||
|
"admin": admin,
|
||||||
}
|
}
|
||||||
|
|
||||||
server_location = server_location.rstrip("/")
|
server_location = server_location.rstrip("/")
|
||||||
|
@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def register_new_user(user, password, server_location, shared_secret):
|
def register_new_user(user, password, server_location, shared_secret, admin):
|
||||||
if not user:
|
if not user:
|
||||||
try:
|
try:
|
||||||
default_user = getpass.getuser()
|
default_user = getpass.getuser()
|
||||||
|
@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
|
||||||
print "Passwords do not match"
|
print "Passwords do not match"
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
request_registration(user, password, server_location, shared_secret)
|
if not admin:
|
||||||
|
admin = raw_input("Make admin [no]: ")
|
||||||
|
if admin in ("y", "yes", "true"):
|
||||||
|
admin = True
|
||||||
|
else:
|
||||||
|
admin = False
|
||||||
|
|
||||||
|
request_registration(user, password, server_location, shared_secret, bool(admin))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -119,6 +134,11 @@ if __name__ == "__main__":
|
||||||
default=None,
|
default=None,
|
||||||
help="New password for user. Will prompt if omitted.",
|
help="New password for user. Will prompt if omitted.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-a", "--admin",
|
||||||
|
action="store_true",
|
||||||
|
help="Register new user as an admin. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
|
||||||
group = parser.add_mutually_exclusive_group(required=True)
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
@ -151,4 +171,4 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
secret = args.shared_secret
|
secret = args.shared_secret
|
||||||
|
|
||||||
register_new_user(args.user, args.password, args.server_url, secret)
|
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
|
||||||
|
|
|
@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
|
||||||
|
|
||||||
|
|
||||||
BOOLEAN_COLUMNS = {
|
BOOLEAN_COLUMNS = {
|
||||||
"events": ["processed", "outlier"],
|
"events": ["processed", "outlier", "contains_url"],
|
||||||
"rooms": ["is_public"],
|
"rooms": ["is_public"],
|
||||||
"event_edges": ["is_state"],
|
"event_edges": ["is_state"],
|
||||||
"presence_list": ["accepted"],
|
"presence_list": ["accepted"],
|
||||||
|
@ -92,8 +92,12 @@ class Store(object):
|
||||||
|
|
||||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
||||||
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
|
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
|
||||||
|
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
|
||||||
|
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
|
||||||
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
|
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
|
||||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
|
_simple_select_one_onecol_txn = SQLBaseStore.__dict__[
|
||||||
|
"_simple_select_one_onecol_txn"
|
||||||
|
]
|
||||||
|
|
||||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||||
|
@ -158,31 +162,40 @@ class Porter(object):
|
||||||
def setup_table(self, table):
|
def setup_table(self, table):
|
||||||
if table in APPEND_ONLY_TABLES:
|
if table in APPEND_ONLY_TABLES:
|
||||||
# It's safe to just carry on inserting.
|
# It's safe to just carry on inserting.
|
||||||
next_chunk = yield self.postgres_store._simple_select_one_onecol(
|
row = yield self.postgres_store._simple_select_one(
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
keyvalues={"table_name": table},
|
keyvalues={"table_name": table},
|
||||||
retcol="rowid",
|
retcols=("forward_rowid", "backward_rowid"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
total_to_port = None
|
total_to_port = None
|
||||||
if next_chunk is None:
|
if row is None:
|
||||||
if table == "sent_transactions":
|
if table == "sent_transactions":
|
||||||
next_chunk, already_ported, total_to_port = (
|
forward_chunk, already_ported, total_to_port = (
|
||||||
yield self._setup_sent_transactions()
|
yield self._setup_sent_transactions()
|
||||||
)
|
)
|
||||||
|
backward_chunk = 0
|
||||||
else:
|
else:
|
||||||
yield self.postgres_store._simple_insert(
|
yield self.postgres_store._simple_insert(
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
values={"table_name": table, "rowid": 1}
|
values={
|
||||||
|
"table_name": table,
|
||||||
|
"forward_rowid": 1,
|
||||||
|
"backward_rowid": 0,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
next_chunk = 1
|
forward_chunk = 1
|
||||||
|
backward_chunk = 0
|
||||||
already_ported = 0
|
already_ported = 0
|
||||||
|
else:
|
||||||
|
forward_chunk = row["forward_rowid"]
|
||||||
|
backward_chunk = row["backward_rowid"]
|
||||||
|
|
||||||
if total_to_port is None:
|
if total_to_port is None:
|
||||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||||
table, next_chunk
|
table, forward_chunk, backward_chunk
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
def delete_all(txn):
|
def delete_all(txn):
|
||||||
|
@ -196,46 +209,85 @@ class Porter(object):
|
||||||
|
|
||||||
yield self.postgres_store._simple_insert(
|
yield self.postgres_store._simple_insert(
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
values={"table_name": table, "rowid": 0}
|
values={
|
||||||
|
"table_name": table,
|
||||||
|
"forward_rowid": 1,
|
||||||
|
"backward_rowid": 0,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
next_chunk = 1
|
forward_chunk = 1
|
||||||
|
backward_chunk = 0
|
||||||
|
|
||||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||||
table, next_chunk
|
table, forward_chunk, backward_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((table, already_ported, total_to_port, next_chunk))
|
defer.returnValue(
|
||||||
|
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_table(self, table, postgres_size, table_size, next_chunk):
|
def handle_table(self, table, postgres_size, table_size, forward_chunk,
|
||||||
|
backward_chunk):
|
||||||
if not table_size:
|
if not table_size:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.progress.add_table(table, postgres_size, table_size)
|
self.progress.add_table(table, postgres_size, table_size)
|
||||||
|
|
||||||
if table == "event_search":
|
if table == "event_search":
|
||||||
yield self.handle_search_table(postgres_size, table_size, next_chunk)
|
yield self.handle_search_table(
|
||||||
|
postgres_size, table_size, forward_chunk, backward_chunk
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
select = (
|
forward_select = (
|
||||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||||
% (table,)
|
% (table,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
backward_select = (
|
||||||
|
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
|
||||||
|
% (table,)
|
||||||
|
)
|
||||||
|
|
||||||
|
do_forward = [True]
|
||||||
|
do_backward = [True]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
def r(txn):
|
def r(txn):
|
||||||
txn.execute(select, (next_chunk, self.batch_size,))
|
forward_rows = []
|
||||||
rows = txn.fetchall()
|
backward_rows = []
|
||||||
|
if do_forward[0]:
|
||||||
|
txn.execute(forward_select, (forward_chunk, self.batch_size,))
|
||||||
|
forward_rows = txn.fetchall()
|
||||||
|
if not forward_rows:
|
||||||
|
do_forward[0] = False
|
||||||
|
|
||||||
|
if do_backward[0]:
|
||||||
|
txn.execute(backward_select, (backward_chunk, self.batch_size,))
|
||||||
|
backward_rows = txn.fetchall()
|
||||||
|
if not backward_rows:
|
||||||
|
do_backward[0] = False
|
||||||
|
|
||||||
|
if forward_rows or backward_rows:
|
||||||
headers = [column[0] for column in txn.description]
|
headers = [column[0] for column in txn.description]
|
||||||
|
else:
|
||||||
|
headers = None
|
||||||
|
|
||||||
return headers, rows
|
return headers, forward_rows, backward_rows
|
||||||
|
|
||||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
headers, frows, brows = yield self.sqlite_store.runInteraction(
|
||||||
|
"select", r
|
||||||
|
)
|
||||||
|
|
||||||
if rows:
|
if frows or brows:
|
||||||
next_chunk = rows[-1][0] + 1
|
if frows:
|
||||||
|
forward_chunk = max(row[0] for row in frows) + 1
|
||||||
|
if brows:
|
||||||
|
backward_chunk = min(row[0] for row in brows) - 1
|
||||||
|
|
||||||
|
rows = frows + brows
|
||||||
self._convert_rows(table, headers, rows)
|
self._convert_rows(table, headers, rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn):
|
||||||
|
@ -247,7 +299,10 @@ class Porter(object):
|
||||||
txn,
|
txn,
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
keyvalues={"table_name": table},
|
keyvalues={"table_name": table},
|
||||||
updatevalues={"rowid": next_chunk},
|
updatevalues={
|
||||||
|
"forward_rowid": forward_chunk,
|
||||||
|
"backward_rowid": backward_chunk,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.postgres_store.execute(insert)
|
yield self.postgres_store.execute(insert)
|
||||||
|
@ -259,7 +314,8 @@ class Porter(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_search_table(self, postgres_size, table_size, next_chunk):
|
def handle_search_table(self, postgres_size, table_size, forward_chunk,
|
||||||
|
backward_chunk):
|
||||||
select = (
|
select = (
|
||||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||||
" FROM event_search as es"
|
" FROM event_search as es"
|
||||||
|
@ -270,7 +326,7 @@ class Porter(object):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
def r(txn):
|
def r(txn):
|
||||||
txn.execute(select, (next_chunk, self.batch_size,))
|
txn.execute(select, (forward_chunk, self.batch_size,))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
headers = [column[0] for column in txn.description]
|
headers = [column[0] for column in txn.description]
|
||||||
|
|
||||||
|
@ -279,7 +335,7 @@ class Porter(object):
|
||||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||||
|
|
||||||
if rows:
|
if rows:
|
||||||
next_chunk = rows[-1][0] + 1
|
forward_chunk = rows[-1][0] + 1
|
||||||
|
|
||||||
# We have to treat event_search differently since it has a
|
# We have to treat event_search differently since it has a
|
||||||
# different structure in the two different databases.
|
# different structure in the two different databases.
|
||||||
|
@ -312,7 +368,10 @@ class Porter(object):
|
||||||
txn,
|
txn,
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
keyvalues={"table_name": "event_search"},
|
keyvalues={"table_name": "event_search"},
|
||||||
updatevalues={"rowid": next_chunk},
|
updatevalues={
|
||||||
|
"forward_rowid": forward_chunk,
|
||||||
|
"backward_rowid": backward_chunk,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.postgres_store.execute(insert)
|
yield self.postgres_store.execute(insert)
|
||||||
|
@ -324,7 +383,6 @@ class Porter(object):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def setup_db(self, db_config, database_engine):
|
def setup_db(self, db_config, database_engine):
|
||||||
db_conn = database_engine.module.connect(
|
db_conn = database_engine.module.connect(
|
||||||
**{
|
**{
|
||||||
|
@ -395,10 +453,32 @@ class Porter(object):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"CREATE TABLE port_from_sqlite3 ("
|
"CREATE TABLE port_from_sqlite3 ("
|
||||||
" table_name varchar(100) NOT NULL UNIQUE,"
|
" table_name varchar(100) NOT NULL UNIQUE,"
|
||||||
" rowid bigint NOT NULL"
|
" forward_rowid bigint NOT NULL,"
|
||||||
|
" backward_rowid bigint NOT NULL"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The old port script created a table with just a "rowid" column.
|
||||||
|
# We want people to be able to rerun this script from an old port
|
||||||
|
# so that they can pick up any missing events that were not
|
||||||
|
# ported across.
|
||||||
|
def alter_table(txn):
|
||||||
|
txn.execute(
|
||||||
|
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||||
|
" RENAME rowid TO forward_rowid"
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||||
|
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.postgres_store.runInteraction(
|
||||||
|
"alter_table", alter_table
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("Failed to create port table: %s", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.postgres_store.runInteraction(
|
yield self.postgres_store.runInteraction(
|
||||||
"create_port_table", create_port_table
|
"create_port_table", create_port_table
|
||||||
|
@ -458,7 +538,7 @@ class Porter(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _setup_sent_transactions(self):
|
def _setup_sent_transactions(self):
|
||||||
# Only save things from the last day
|
# Only save things from the last day
|
||||||
yesterday = int(time.time()*1000) - 86400000
|
yesterday = int(time.time() * 1000) - 86400000
|
||||||
|
|
||||||
# And save the max transaction id from each destination
|
# And save the max transaction id from each destination
|
||||||
select = (
|
select = (
|
||||||
|
@ -514,7 +594,11 @@ class Porter(object):
|
||||||
|
|
||||||
yield self.postgres_store._simple_insert(
|
yield self.postgres_store._simple_insert(
|
||||||
table="port_from_sqlite3",
|
table="port_from_sqlite3",
|
||||||
values={"table_name": "sent_transactions", "rowid": next_chunk}
|
values={
|
||||||
|
"table_name": "sent_transactions",
|
||||||
|
"forward_rowid": next_chunk,
|
||||||
|
"backward_rowid": 0,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sent_table_size(txn):
|
def get_sent_table_size(txn):
|
||||||
|
@ -535,13 +619,18 @@ class Porter(object):
|
||||||
defer.returnValue((next_chunk, inserted_rows, total_count))
|
defer.returnValue((next_chunk, inserted_rows, total_count))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_remaining_count_to_port(self, table, next_chunk):
|
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||||
rows = yield self.sqlite_store.execute_sql(
|
frows = yield self.sqlite_store.execute_sql(
|
||||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
||||||
next_chunk,
|
forward_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(rows[0][0])
|
brows = yield self.sqlite_store.execute_sql(
|
||||||
|
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
|
||||||
|
backward_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(frows[0][0] + brows[0][0])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_already_ported_count(self, table):
|
def _get_already_ported_count(self, table):
|
||||||
|
@ -552,10 +641,10 @@ class Porter(object):
|
||||||
defer.returnValue(rows[0][0])
|
defer.returnValue(rows[0][0])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_total_count_to_port(self, table, next_chunk):
|
def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||||
remaining, done = yield defer.gatherResults(
|
remaining, done = yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self._get_remaining_count_to_port(table, next_chunk),
|
self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
|
||||||
self._get_already_ported_count(table),
|
self._get_already_ported_count(table),
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
|
@ -686,7 +775,7 @@ class CursesProgress(Progress):
|
||||||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
||||||
|
|
||||||
self.stdscr.addstr(
|
self.stdscr.addstr(
|
||||||
i+2, left_margin + max_len - len(table),
|
i + 2, left_margin + max_len - len(table),
|
||||||
table,
|
table,
|
||||||
curses.A_BOLD | color,
|
curses.A_BOLD | color,
|
||||||
)
|
)
|
||||||
|
@ -694,18 +783,18 @@ class CursesProgress(Progress):
|
||||||
size = 20
|
size = 20
|
||||||
|
|
||||||
progress = "[%s%s]" % (
|
progress = "[%s%s]" % (
|
||||||
"#" * int(perc*size/100),
|
"#" * int(perc * size / 100),
|
||||||
" " * (size - int(perc*size/100)),
|
" " * (size - int(perc * size / 100)),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stdscr.addstr(
|
self.stdscr.addstr(
|
||||||
i+2, left_margin + max_len + middle_space,
|
i + 2, left_margin + max_len + middle_space,
|
||||||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.finished:
|
if self.finished:
|
||||||
self.stdscr.addstr(
|
self.stdscr.addstr(
|
||||||
rows-1, 0,
|
rows - 1, 0,
|
||||||
"Press any key to exit...",
|
"Press any key to exit...",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,5 @@ ignore =
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 90
|
max-line-length = 90
|
||||||
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||||
|
ignore = W503
|
||||||
[pep8]
|
|
||||||
max-line-length = 90
|
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.16.1-r1"
|
__version__ = "0.17.0"
|
||||||
|
|
|
@ -13,22 +13,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pymacaroons
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
|
||||||
from synapse.types import Requester, UserID, get_domain_from_id
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
|
||||||
from synapse.util.metrics import Measure
|
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
import logging
|
import synapse.types
|
||||||
import pymacaroons
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
|
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||||
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class Auth(object):
|
||||||
"user_id = ",
|
"user_id = ",
|
||||||
])
|
])
|
||||||
|
|
||||||
def check(self, event, auth_events):
|
def check(self, event, auth_events, do_sig_check=True):
|
||||||
""" Checks if this event is correctly authed.
|
""" Checks if this event is correctly authed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -79,6 +79,13 @@ class Auth(object):
|
||||||
|
|
||||||
if not hasattr(event, "room_id"):
|
if not hasattr(event, "room_id"):
|
||||||
raise AuthError(500, "Event has no room_id: %s" % event)
|
raise AuthError(500, "Event has no room_id: %s" % event)
|
||||||
|
|
||||||
|
sender_domain = get_domain_from_id(event.sender)
|
||||||
|
|
||||||
|
# Check the sender's domain has signed the event
|
||||||
|
if do_sig_check and not event.signatures.get(sender_domain):
|
||||||
|
raise AuthError(403, "Event not signed by sending server")
|
||||||
|
|
||||||
if auth_events is None:
|
if auth_events is None:
|
||||||
# Oh, we don't know what the state of the room was, so we
|
# Oh, we don't know what the state of the room was, so we
|
||||||
# are trusting that this is allowed (at least for now)
|
# are trusting that this is allowed (at least for now)
|
||||||
|
@ -86,6 +93,12 @@ class Auth(object):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
|
room_id_domain = get_domain_from_id(event.room_id)
|
||||||
|
if room_id_domain != sender_domain:
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"Creation event's room_id domain does not match sender's"
|
||||||
|
)
|
||||||
# FIXME
|
# FIXME
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -108,6 +121,22 @@ class Auth(object):
|
||||||
|
|
||||||
# FIXME: Temp hack
|
# FIXME: Temp hack
|
||||||
if event.type == EventTypes.Aliases:
|
if event.type == EventTypes.Aliases:
|
||||||
|
if not event.is_state():
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"Alias event must be a state event",
|
||||||
|
)
|
||||||
|
if not event.state_key:
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"Alias event must have non-empty state_key"
|
||||||
|
)
|
||||||
|
sender_domain = get_domain_from_id(event.sender)
|
||||||
|
if event.state_key != sender_domain:
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"Alias event's state_key does not match sender's domain"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -347,6 +376,10 @@ class Auth(object):
|
||||||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||||
if not self._verify_third_party_invite(event, auth_events):
|
if not self._verify_third_party_invite(event, auth_events):
|
||||||
raise AuthError(403, "You are not invited to this room.")
|
raise AuthError(403, "You are not invited to this room.")
|
||||||
|
if target_banned:
|
||||||
|
raise AuthError(
|
||||||
|
403, "%s is banned from the room" % (target_user_id,)
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if Membership.JOIN != membership:
|
if Membership.JOIN != membership:
|
||||||
|
@ -537,9 +570,7 @@ class Auth(object):
|
||||||
Args:
|
Args:
|
||||||
request - An HTTP request with an access_token query parameter.
|
request - An HTTP request with an access_token query parameter.
|
||||||
Returns:
|
Returns:
|
||||||
tuple of:
|
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||||
UserID (str)
|
|
||||||
Access token ID (str)
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
|
@ -548,9 +579,7 @@ class Auth(object):
|
||||||
user_id = yield self._get_appservice_user_id(request.args)
|
user_id = yield self._get_appservice_user_id(request.args)
|
||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
defer.returnValue(
|
defer.returnValue(synapse.types.create_requester(user_id))
|
||||||
Requester(UserID.from_string(user_id), "", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token = request.args["access_token"][0]
|
access_token = request.args["access_token"][0]
|
||||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||||
|
@ -558,6 +587,10 @@ class Auth(object):
|
||||||
token_id = user_info["token_id"]
|
token_id = user_info["token_id"]
|
||||||
is_guest = user_info["is_guest"]
|
is_guest = user_info["is_guest"]
|
||||||
|
|
||||||
|
# device_id may not be present if get_user_by_access_token has been
|
||||||
|
# stubbed out.
|
||||||
|
device_id = user_info.get("device_id")
|
||||||
|
|
||||||
ip_addr = self.hs.get_ip_from_request(request)
|
ip_addr = self.hs.get_ip_from_request(request)
|
||||||
user_agent = request.requestHeaders.getRawHeaders(
|
user_agent = request.requestHeaders.getRawHeaders(
|
||||||
"User-Agent",
|
"User-Agent",
|
||||||
|
@ -569,7 +602,8 @@ class Auth(object):
|
||||||
user=user,
|
user=user,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
user_agent=user_agent
|
user_agent=user_agent,
|
||||||
|
device_id=device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_guest and not allow_guest:
|
if is_guest and not allow_guest:
|
||||||
|
@ -579,7 +613,8 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user.to_string()
|
||||||
|
|
||||||
defer.returnValue(Requester(user, token_id, is_guest))
|
defer.returnValue(synapse.types.create_requester(
|
||||||
|
user, token_id, is_guest, device_id))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||||
|
@ -629,7 +664,10 @@ class Auth(object):
|
||||||
except AuthError:
|
except AuthError:
|
||||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||||
# have been re-issued as macaroons.
|
# have been re-issued as macaroons.
|
||||||
|
if self.hs.config.expire_access_token:
|
||||||
|
raise
|
||||||
ret = yield self._look_up_user_by_access_token(token)
|
ret = yield self._look_up_user_by_access_token(token)
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -664,6 +702,7 @@ class Auth(object):
|
||||||
"user": user,
|
"user": user,
|
||||||
"is_guest": True,
|
"is_guest": True,
|
||||||
"token_id": None,
|
"token_id": None,
|
||||||
|
"device_id": None,
|
||||||
}
|
}
|
||||||
elif rights == "delete_pusher":
|
elif rights == "delete_pusher":
|
||||||
# We don't store these tokens in the database
|
# We don't store these tokens in the database
|
||||||
|
@ -671,13 +710,20 @@ class Auth(object):
|
||||||
"user": user,
|
"user": user,
|
||||||
"is_guest": False,
|
"is_guest": False,
|
||||||
"token_id": None,
|
"token_id": None,
|
||||||
|
"device_id": None,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# This codepath exists so that we can actually return a
|
# This codepath exists for several reasons:
|
||||||
# token ID, because we use token IDs in place of device
|
# * so that we can actually return a token ID, which is used
|
||||||
# identifiers throughout the codebase.
|
# in some parts of the schema (where we probably ought to
|
||||||
# TODO(daniel): Remove this fallback when device IDs are
|
# use device IDs instead)
|
||||||
# properly implemented.
|
# * the only way we currently have to invalidate an
|
||||||
|
# access_token is by removing it from the database, so we
|
||||||
|
# have to check here that it is still in the db
|
||||||
|
# * some attributes (notably device_id) aren't stored in the
|
||||||
|
# macaroon. They probably should be.
|
||||||
|
# TODO: build the dictionary from the macaroon once the
|
||||||
|
# above are fixed
|
||||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||||
if ret["user"] != user:
|
if ret["user"] != user:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
@ -751,10 +797,14 @@ class Auth(object):
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
# we use ret.get() below because *lots* of unit tests stub out
|
||||||
|
# get_user_by_access_token in a way where it only returns a couple of
|
||||||
|
# the fields.
|
||||||
user_info = {
|
user_info = {
|
||||||
"user": UserID.from_string(ret.get("name")),
|
"user": UserID.from_string(ret.get("name")),
|
||||||
"token_id": ret.get("token_id", None),
|
"token_id": ret.get("token_id", None),
|
||||||
"is_guest": False,
|
"is_guest": False,
|
||||||
|
"device_id": ret.get("device_id"),
|
||||||
}
|
}
|
||||||
defer.returnValue(user_info)
|
defer.returnValue(user_info)
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,10 @@ class Codes(object):
|
||||||
TOO_LARGE = "M_TOO_LARGE"
|
TOO_LARGE = "M_TOO_LARGE"
|
||||||
EXCLUSIVE = "M_EXCLUSIVE"
|
EXCLUSIVE = "M_EXCLUSIVE"
|
||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||||
|
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||||
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
|
|
@ -191,6 +191,17 @@ class Filter(object):
|
||||||
def __init__(self, filter_json):
|
def __init__(self, filter_json):
|
||||||
self.filter_json = filter_json
|
self.filter_json = filter_json
|
||||||
|
|
||||||
|
self.types = self.filter_json.get("types", None)
|
||||||
|
self.not_types = self.filter_json.get("not_types", [])
|
||||||
|
|
||||||
|
self.rooms = self.filter_json.get("rooms", None)
|
||||||
|
self.not_rooms = self.filter_json.get("not_rooms", [])
|
||||||
|
|
||||||
|
self.senders = self.filter_json.get("senders", None)
|
||||||
|
self.not_senders = self.filter_json.get("not_senders", [])
|
||||||
|
|
||||||
|
self.contains_url = self.filter_json.get("contains_url", None)
|
||||||
|
|
||||||
def check(self, event):
|
def check(self, event):
|
||||||
"""Checks whether the filter matches the given event.
|
"""Checks whether the filter matches the given event.
|
||||||
|
|
||||||
|
@ -209,9 +220,10 @@ class Filter(object):
|
||||||
event.get("room_id", None),
|
event.get("room_id", None),
|
||||||
sender,
|
sender,
|
||||||
event.get("type", None),
|
event.get("type", None),
|
||||||
|
"url" in event.get("content", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_fields(self, room_id, sender, event_type):
|
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||||
"""Checks whether the filter matches the given event fields.
|
"""Checks whether the filter matches the given event fields.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -225,15 +237,20 @@ class Filter(object):
|
||||||
|
|
||||||
for name, match_func in literal_keys.items():
|
for name, match_func in literal_keys.items():
|
||||||
not_name = "not_%s" % (name,)
|
not_name = "not_%s" % (name,)
|
||||||
disallowed_values = self.filter_json.get(not_name, [])
|
disallowed_values = getattr(self, not_name)
|
||||||
if any(map(match_func, disallowed_values)):
|
if any(map(match_func, disallowed_values)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
allowed_values = self.filter_json.get(name, None)
|
allowed_values = getattr(self, name)
|
||||||
if allowed_values is not None:
|
if allowed_values is not None:
|
||||||
if not any(map(match_func, allowed_values)):
|
if not any(map(match_func, allowed_values)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
contains_url_filter = self.filter_json.get("contains_url")
|
||||||
|
if contains_url_filter is not None:
|
||||||
|
if contains_url_filter != contains_url:
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def filter_rooms(self, room_ids):
|
def filter_rooms(self, room_ids):
|
||||||
|
|
|
@ -16,13 +16,11 @@
|
||||||
import sys
|
import sys
|
||||||
sys.dont_write_bytecode = True
|
sys.dont_write_bytecode = True
|
||||||
|
|
||||||
from synapse.python_dependencies import (
|
from synapse import python_dependencies # noqa: E402
|
||||||
check_requirements, MissingRequirementError
|
|
||||||
) # NOQA
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
check_requirements()
|
python_dependencies.check_requirements()
|
||||||
except MissingRequirementError as e:
|
except python_dependencies.MissingRequirementError as e:
|
||||||
message = "\n".join([
|
message = "\n".join([
|
||||||
"Missing Requirement: %s" % (e.message,),
|
"Missing Requirement: %s" % (e.message,),
|
||||||
"To install run:",
|
"To install run:",
|
||||||
|
|
206
synapse/app/federation_reader.py
Normal file
206
synapse/app/federation_reader.py
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||||
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.api.urls import FEDERATION_PREFIX
|
||||||
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
|
from synapse.crypto import context_factory
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.federation_reader")
|
||||||
|
|
||||||
|
|
||||||
|
class FederationReaderSlavedStore(
|
||||||
|
SlavedEventStore,
|
||||||
|
SlavedKeyStore,
|
||||||
|
RoomStore,
|
||||||
|
DirectoryStore,
|
||||||
|
TransactionStore,
|
||||||
|
BaseSlavedStore,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FederationReaderServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
elif name == "federation":
|
||||||
|
resources.update({
|
||||||
|
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||||
|
})
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse federation reader now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.worker_replication_url
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
args = store.stream_positions()
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
yield store.process_replication(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
yield sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse federation reader", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.federation_reader"
|
||||||
|
|
||||||
|
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
|
||||||
|
ss = FederationReaderServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ss.setup()
|
||||||
|
ss.get_handlers()
|
||||||
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ss.get_datastore().start_profiling()
|
||||||
|
ss.replicate()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-federation-reader",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
|
@ -51,6 +51,7 @@ from synapse.api.urls import (
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.metrics import register_memory_metrics
|
||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
|
@ -147,7 +148,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
MEDIA_PREFIX: media_repo,
|
MEDIA_PREFIX: media_repo,
|
||||||
LEGACY_MEDIA_PREFIX: media_repo,
|
LEGACY_MEDIA_PREFIX: media_repo,
|
||||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||||
self, self.config.uploads_path, self.auth, self.content_addr
|
self, self.config.uploads_path
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -284,7 +285,7 @@ def setup(config_options):
|
||||||
# check any extra requirements we have now we have a config
|
# check any extra requirements we have now we have a config
|
||||||
check_requirements(config)
|
check_requirements(config)
|
||||||
|
|
||||||
version_string = get_version_string("Synapse", synapse)
|
version_string = "Synapse/" + get_version_string(synapse)
|
||||||
|
|
||||||
logger.info("Server hostname: %s", config.server_name)
|
logger.info("Server hostname: %s", config.server_name)
|
||||||
logger.info("Server version: %s", version_string)
|
logger.info("Server version: %s", version_string)
|
||||||
|
@ -301,7 +302,6 @@ def setup(config_options):
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
config=config,
|
config=config,
|
||||||
content_addr=config.content_addr,
|
|
||||||
version_string=version_string,
|
version_string=version_string,
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
)
|
)
|
||||||
|
@ -336,6 +336,8 @@ def setup(config_options):
|
||||||
hs.get_datastore().start_doing_background_updates()
|
hs.get_datastore().start_doing_background_updates()
|
||||||
hs.get_replication_layer().start_get_pdu_cache()
|
hs.get_replication_layer().start_get_pdu_cache()
|
||||||
|
|
||||||
|
register_memory_metrics(hs)
|
||||||
|
|
||||||
reactor.callWhenRunning(start)
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
|
@ -273,7 +273,7 @@ def start(config_options):
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
config=config,
|
config=config,
|
||||||
version_string=get_version_string("Synapse", synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -424,7 +424,7 @@ def start(config_options):
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
config=config,
|
config=config,
|
||||||
version_string=get_version_string("Synapse", synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
application_service_handler=SynchrotronApplicationService(),
|
application_service_handler=SynchrotronApplicationService(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,40 +13,88 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
|
|
||||||
|
MISSING_LDAP3 = (
|
||||||
|
"Missing ldap3 library. This is required for LDAP Authentication."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LDAPMode(object):
|
||||||
|
SIMPLE = "simple",
|
||||||
|
SEARCH = "search",
|
||||||
|
|
||||||
|
LIST = (SIMPLE, SEARCH)
|
||||||
|
|
||||||
|
|
||||||
class LDAPConfig(Config):
|
class LDAPConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
ldap_config = config.get("ldap_config", None)
|
ldap_config = config.get("ldap_config", {})
|
||||||
if ldap_config:
|
|
||||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||||
self.ldap_server = ldap_config["server"]
|
|
||||||
self.ldap_port = ldap_config["port"]
|
if self.ldap_enabled:
|
||||||
self.ldap_tls = ldap_config.get("tls", False)
|
# verify dependencies are available
|
||||||
self.ldap_search_base = ldap_config["search_base"]
|
try:
|
||||||
self.ldap_search_property = ldap_config["search_property"]
|
import ldap3
|
||||||
self.ldap_email_property = ldap_config["email_property"]
|
ldap3 # to stop unused lint
|
||||||
self.ldap_full_name_property = ldap_config["full_name_property"]
|
except ImportError:
|
||||||
else:
|
raise ConfigError(MISSING_LDAP3)
|
||||||
self.ldap_enabled = False
|
|
||||||
self.ldap_server = None
|
self.ldap_mode = LDAPMode.SIMPLE
|
||||||
self.ldap_port = None
|
|
||||||
self.ldap_tls = False
|
# verify config sanity
|
||||||
self.ldap_search_base = None
|
self.require_keys(ldap_config, [
|
||||||
self.ldap_search_property = None
|
"uri",
|
||||||
self.ldap_email_property = None
|
"base",
|
||||||
self.ldap_full_name_property = None
|
"attributes",
|
||||||
|
])
|
||||||
|
|
||||||
|
self.ldap_uri = ldap_config["uri"]
|
||||||
|
self.ldap_start_tls = ldap_config.get("start_tls", False)
|
||||||
|
self.ldap_base = ldap_config["base"]
|
||||||
|
self.ldap_attributes = ldap_config["attributes"]
|
||||||
|
|
||||||
|
if "bind_dn" in ldap_config:
|
||||||
|
self.ldap_mode = LDAPMode.SEARCH
|
||||||
|
self.require_keys(ldap_config, [
|
||||||
|
"bind_dn",
|
||||||
|
"bind_password",
|
||||||
|
])
|
||||||
|
|
||||||
|
self.ldap_bind_dn = ldap_config["bind_dn"]
|
||||||
|
self.ldap_bind_password = ldap_config["bind_password"]
|
||||||
|
self.ldap_filter = ldap_config.get("filter", None)
|
||||||
|
|
||||||
|
# verify attribute lookup
|
||||||
|
self.require_keys(ldap_config['attributes'], [
|
||||||
|
"uid",
|
||||||
|
"name",
|
||||||
|
"mail",
|
||||||
|
])
|
||||||
|
|
||||||
|
def require_keys(self, config, required):
|
||||||
|
missing = [key for key in required if key not in config]
|
||||||
|
if missing:
|
||||||
|
raise ConfigError(
|
||||||
|
"LDAP enabled but missing required config values: {}".format(
|
||||||
|
", ".join(missing)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
# ldap_config:
|
# ldap_config:
|
||||||
# enabled: true
|
# enabled: true
|
||||||
# server: "ldap://localhost"
|
# uri: "ldap://ldap.example.com:389"
|
||||||
# port: 389
|
# start_tls: true
|
||||||
# tls: false
|
# base: "ou=users,dc=example,dc=com"
|
||||||
# search_base: "ou=Users,dc=example,dc=com"
|
# attributes:
|
||||||
# search_property: "cn"
|
# uid: "cn"
|
||||||
# email_property: "email"
|
# mail: "email"
|
||||||
# full_name_property: "givenName"
|
# name: "givenName"
|
||||||
|
# #bind_dn:
|
||||||
|
# #bind_password:
|
||||||
|
# #filter: "(objectClass=posixAccount)"
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -23,10 +23,14 @@ class PasswordConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
password_config = config.get("password_config", {})
|
password_config = config.get("password_config", {})
|
||||||
self.password_enabled = password_config.get("enabled", True)
|
self.password_enabled = password_config.get("enabled", True)
|
||||||
|
self.password_pepper = password_config.get("pepper", "")
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
return """
|
return """
|
||||||
# Enable password for login.
|
# Enable password for login.
|
||||||
password_config:
|
password_config:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
# Uncomment and change to a secret random string for extra security.
|
||||||
|
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||||
|
#pepper: ""
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -107,26 +107,6 @@ class ServerConfig(Config):
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
# Attempt to guess the content_addr for the v0 content repostitory
|
|
||||||
content_addr = config.get("content_addr")
|
|
||||||
if not content_addr:
|
|
||||||
for listener in self.listeners:
|
|
||||||
if listener["type"] == "http" and not listener.get("tls", False):
|
|
||||||
unsecure_port = listener["port"]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Could not determine 'content_addr'")
|
|
||||||
|
|
||||||
host = self.server_name
|
|
||||||
if ':' not in host:
|
|
||||||
host = "%s:%d" % (host, unsecure_port)
|
|
||||||
else:
|
|
||||||
host = host.split(':')[0]
|
|
||||||
host = "%s:%d" % (host, unsecure_port)
|
|
||||||
content_addr = "http://%s" % (host,)
|
|
||||||
|
|
||||||
self.content_addr = content_addr
|
|
||||||
|
|
||||||
def default_config(self, server_name, **kwargs):
|
def default_config(self, server_name, **kwargs):
|
||||||
if ":" in server_name:
|
if ":" in server_name:
|
||||||
bind_port = int(server_name.split(":")[1])
|
bind_port = int(server_name.split(":")[1])
|
||||||
|
@ -169,7 +149,6 @@ class ServerConfig(Config):
|
||||||
# room directory.
|
# room directory.
|
||||||
# secondary_directory_servers:
|
# secondary_directory_servers:
|
||||||
# - matrix.org
|
# - matrix.org
|
||||||
# - vector.im
|
|
||||||
|
|
||||||
# List of ports that Synapse should listen on, their purpose and their
|
# List of ports that Synapse should listen on, their purpose and their
|
||||||
# configuration.
|
# configuration.
|
||||||
|
|
|
@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.remote_key = defer.Deferred()
|
self.remote_key = defer.Deferred()
|
||||||
self.host = None
|
self.host = None
|
||||||
|
self._peer = None
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
self.host = self.transport.getHost()
|
self._peer = self.transport.getPeer()
|
||||||
logger.debug("Connected to %s", self.host)
|
logger.debug("Connected to %s", self._peer)
|
||||||
|
|
||||||
self.sendCommand(b"GET", self.path)
|
self.sendCommand(b"GET", self.path)
|
||||||
if self.host:
|
if self.host:
|
||||||
self.sendHeader(b"Host", self.host)
|
self.sendHeader(b"Host", self.host)
|
||||||
|
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
|
|
||||||
def on_timeout(self):
|
def on_timeout(self):
|
||||||
logger.debug("Timeout waiting for response from %s", self.host)
|
logger.debug(
|
||||||
|
"Timeout waiting for response from %s: %s",
|
||||||
|
self.host, self._peer,
|
||||||
|
)
|
||||||
self.errback(IOError("Timeout waiting for response"))
|
self.errback(IOError("Timeout waiting for response"))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
|
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
|
||||||
def protocol(self):
|
def protocol(self):
|
||||||
protocol = SynapseKeyClientProtocol()
|
protocol = SynapseKeyClientProtocol()
|
||||||
protocol.path = self.path
|
protocol.path = self.path
|
||||||
|
protocol.host = self.host
|
||||||
return protocol
|
return protocol
|
||||||
|
|
|
@ -44,7 +44,21 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||||
|
"server_name", "key_ids", "json_object", "deferred"
|
||||||
|
))
|
||||||
|
"""
|
||||||
|
A request for a verify key to verify a JSON object.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name(str): The name of the server to verify against.
|
||||||
|
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||||
|
JSON object
|
||||||
|
json_object(dict): The JSON object to verify.
|
||||||
|
deferred(twisted.internet.defer.Deferred):
|
||||||
|
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||||
|
a verify key has been fetched
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
|
@ -74,39 +88,32 @@ class Keyring(object):
|
||||||
list of deferreds indicating success or failure to verify each
|
list of deferreds indicating success or failure to verify each
|
||||||
json object's signature for the given server_name.
|
json object's signature for the given server_name.
|
||||||
"""
|
"""
|
||||||
group_id_to_json = {}
|
verify_requests = []
|
||||||
group_id_to_group = {}
|
|
||||||
group_ids = []
|
|
||||||
|
|
||||||
next_group_id = 0
|
|
||||||
deferreds = {}
|
|
||||||
|
|
||||||
for server_name, json_object in server_and_json:
|
for server_name, json_object in server_and_json:
|
||||||
logger.debug("Verifying for %s", server_name)
|
logger.debug("Verifying for %s", server_name)
|
||||||
group_id = next_group_id
|
|
||||||
next_group_id += 1
|
|
||||||
group_ids.append(group_id)
|
|
||||||
|
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
deferreds[group_id] = defer.fail(SynapseError(
|
deferred = defer.fail(SynapseError(
|
||||||
400,
|
400,
|
||||||
"Not signed with a supported algorithm",
|
"Not signed with a supported algorithm",
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
deferreds[group_id] = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
|
|
||||||
group = KeyGroup(server_name, group_id, key_ids)
|
verify_request = VerifyKeyRequest(
|
||||||
|
server_name, key_ids, json_object, deferred
|
||||||
|
)
|
||||||
|
|
||||||
group_id_to_group[group_id] = group
|
verify_requests.append(verify_request)
|
||||||
group_id_to_json[group_id] = json_object
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_key_deferred(group, deferred):
|
def handle_key_deferred(verify_request):
|
||||||
server_name = group.server_name
|
server_name = verify_request.server_name
|
||||||
try:
|
try:
|
||||||
_, _, key_id, verify_key = yield deferred
|
_, key_id, verify_key = yield verify_request.deferred
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Got IOError when downloading keys for %s: %s %s",
|
"Got IOError when downloading keys for %s: %s %s",
|
||||||
|
@ -128,7 +135,7 @@ class Keyring(object):
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
json_object = group_id_to_json[group.group_id]
|
json_object = verify_request.json_object
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
@ -157,36 +164,34 @@ class Keyring(object):
|
||||||
|
|
||||||
# Actually start fetching keys.
|
# Actually start fetching keys.
|
||||||
wait_on_deferred.addBoth(
|
wait_on_deferred.addBoth(
|
||||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
lambda _: self.get_server_verify_keys(verify_requests)
|
||||||
)
|
)
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
# any lookups waiting will proceed.
|
# any lookups waiting will proceed.
|
||||||
server_to_gids = {}
|
server_to_request_ids = {}
|
||||||
|
|
||||||
def remove_deferreds(res, server_name, group_id):
|
def remove_deferreds(res, server_name, verify_request):
|
||||||
server_to_gids[server_name].discard(group_id)
|
request_id = id(verify_request)
|
||||||
if not server_to_gids[server_name]:
|
server_to_request_ids[server_name].discard(request_id)
|
||||||
|
if not server_to_request_ids[server_name]:
|
||||||
d = server_to_deferred.pop(server_name, None)
|
d = server_to_deferred.pop(server_name, None)
|
||||||
if d:
|
if d:
|
||||||
d.callback(None)
|
d.callback(None)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
for g_id, deferred in deferreds.items():
|
for verify_request in verify_requests:
|
||||||
server_name = group_id_to_group[g_id].server_name
|
server_name = verify_request.server_name
|
||||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
request_id = id(verify_request)
|
||||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||||
|
|
||||||
# Pass those keys to handle_key_deferred so that the json object
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
# signatures can be verified
|
# signatures can be verified
|
||||||
return [
|
return [
|
||||||
preserve_context_over_fn(
|
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||||
handle_key_deferred,
|
for verify_request in verify_requests
|
||||||
group_id_to_group[g_id],
|
|
||||||
deferreds[g_id],
|
|
||||||
)
|
|
||||||
for g_id in group_ids
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -220,7 +225,7 @@ class Keyring(object):
|
||||||
|
|
||||||
d.addBoth(rm, server_name)
|
d.addBoth(rm, server_name)
|
||||||
|
|
||||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
def get_server_verify_keys(self, verify_requests):
|
||||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
each group.
|
each group.
|
||||||
"""
|
"""
|
||||||
|
@ -237,62 +242,64 @@ class Keyring(object):
|
||||||
merged_results = {}
|
merged_results = {}
|
||||||
|
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
for verify_request in verify_requests:
|
||||||
missing_keys.setdefault(group.server_name, set()).update(
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
group.key_ids
|
verify_request.key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for fn in key_fetch_fns:
|
for fn in key_fetch_fns:
|
||||||
results = yield fn(missing_keys.items())
|
results = yield fn(missing_keys.items())
|
||||||
merged_results.update(results)
|
merged_results.update(results)
|
||||||
|
|
||||||
# We now need to figure out which groups we have keys for
|
# We now need to figure out which verify requests we have keys
|
||||||
# and which we don't
|
# for and which we don't
|
||||||
missing_groups = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
requests_missing_keys = []
|
||||||
for key_id in group.key_ids:
|
for verify_request in verify_requests:
|
||||||
if key_id in merged_results[group.server_name]:
|
server_name = verify_request.server_name
|
||||||
|
result_keys = merged_results[server_name]
|
||||||
|
|
||||||
|
if verify_request.deferred.called:
|
||||||
|
# We've already called this deferred, which probably
|
||||||
|
# means that we've already found a key for it.
|
||||||
|
continue
|
||||||
|
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
if key_id in result_keys:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
group_id_to_deferred[group.group_id].callback((
|
verify_request.deferred.callback((
|
||||||
group.group_id,
|
server_name,
|
||||||
group.server_name,
|
|
||||||
key_id,
|
key_id,
|
||||||
merged_results[group.server_name][key_id],
|
result_keys[key_id],
|
||||||
))
|
))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
missing_groups.setdefault(
|
# The else block is only reached if the loop above
|
||||||
group.server_name, []
|
# doesn't break.
|
||||||
).append(group)
|
missing_keys.setdefault(server_name, set()).update(
|
||||||
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
requests_missing_keys.append(verify_request)
|
||||||
|
|
||||||
if not missing_groups:
|
if not missing_keys:
|
||||||
break
|
break
|
||||||
|
|
||||||
missing_keys = {
|
for verify_request in requests_missing_keys.values():
|
||||||
server_name: set(
|
verify_request.deferred.errback(SynapseError(
|
||||||
key_id for group in groups for key_id in group.key_ids
|
|
||||||
)
|
|
||||||
for server_name, groups in missing_groups.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
for group in missing_groups.values():
|
|
||||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
|
||||||
401,
|
401,
|
||||||
"No key for %s with id %s" % (
|
"No key for %s with id %s" % (
|
||||||
group.server_name, group.key_ids,
|
verify_request.server_name, verify_request.key_ids,
|
||||||
),
|
),
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
for deferred in group_id_to_deferred.values():
|
for verify_request in verify_requests:
|
||||||
if not deferred.called:
|
if not verify_request.deferred.called:
|
||||||
deferred.errback(err)
|
verify_request.deferred.errback(err)
|
||||||
|
|
||||||
do_iterations().addErrback(on_err)
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
return group_id_to_deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
res = yield defer.gatherResults(
|
res = yield defer.gatherResults(
|
||||||
|
@ -447,7 +454,7 @@ class Keyring(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
processed_response = yield self.process_v2_response(
|
processed_response = yield self.process_v2_response(
|
||||||
perspective_name, response
|
perspective_name, response, only_from_server=False
|
||||||
)
|
)
|
||||||
|
|
||||||
for server_name, response_keys in processed_response.items():
|
for server_name, response_keys in processed_response.items():
|
||||||
|
@ -527,7 +534,7 @@ class Keyring(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def process_v2_response(self, from_server, response_json,
|
def process_v2_response(self, from_server, response_json,
|
||||||
requested_ids=[]):
|
requested_ids=[], only_from_server=True):
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
response_keys = {}
|
response_keys = {}
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
|
@ -551,6 +558,13 @@ class Keyring(object):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
server_name = response_json["server_name"]
|
server_name = response_json["server_name"]
|
||||||
|
if only_from_server:
|
||||||
|
if server_name != from_server:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected a response for server %r not %r" % (
|
||||||
|
from_server, server_name
|
||||||
|
)
|
||||||
|
)
|
||||||
for key_id in response_json["signatures"].get(server_name, {}):
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -236,9 +236,9 @@ class FederationClient(FederationBase):
|
||||||
# TODO: Rate limit the number of times we try and get the same event.
|
# TODO: Rate limit the number of times we try and get the same event.
|
||||||
|
|
||||||
if self._get_pdu_cache:
|
if self._get_pdu_cache:
|
||||||
e = self._get_pdu_cache.get(event_id)
|
ev = self._get_pdu_cache.get(event_id)
|
||||||
if e:
|
if ev:
|
||||||
defer.returnValue(e)
|
defer.returnValue(ev)
|
||||||
|
|
||||||
pdu = None
|
pdu = None
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
@ -269,7 +269,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except SynapseError:
|
except SynapseError as e:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Failed to get PDU %s from %s because %s",
|
"Failed to get PDU %s from %s because %s",
|
||||||
event_id, destination, e,
|
event_id, destination, e,
|
||||||
|
@ -314,6 +314,42 @@ class FederationClient(FederationBase):
|
||||||
Deferred: Results in a list of PDUs.
|
Deferred: Results in a list of PDUs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First we try and ask for just the IDs, as thats far quicker if
|
||||||
|
# we have most of the state and auth_chain already.
|
||||||
|
# However, this may 404 if the other side has an old synapse.
|
||||||
|
result = yield self.transport_layer.get_room_state_ids(
|
||||||
|
destination, room_id, event_id=event_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_event_ids = result["pdu_ids"]
|
||||||
|
auth_event_ids = result.get("auth_chain_ids", [])
|
||||||
|
|
||||||
|
fetched_events, failed_to_fetch = yield self.get_events(
|
||||||
|
[destination], room_id, set(state_event_ids + auth_event_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
if failed_to_fetch:
|
||||||
|
logger.warn("Failed to get %r", failed_to_fetch)
|
||||||
|
|
||||||
|
event_map = {
|
||||||
|
ev.event_id: ev for ev in fetched_events
|
||||||
|
}
|
||||||
|
|
||||||
|
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
|
||||||
|
auth_chain = [
|
||||||
|
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
|
||||||
|
]
|
||||||
|
|
||||||
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
defer.returnValue((pdus, auth_chain))
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if e.code == 400 or e.code == 404:
|
||||||
|
logger.info("Failed to use get_room_state_ids API, falling back")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
result = yield self.transport_layer.get_room_state(
|
result = yield self.transport_layer.get_room_state(
|
||||||
destination, room_id, event_id=event_id,
|
destination, room_id, event_id=event_id,
|
||||||
)
|
)
|
||||||
|
@ -327,18 +363,93 @@ class FederationClient(FederationBase):
|
||||||
for p in result.get("auth_chain", [])
|
for p in result.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
seen_events = yield self.store.get_events([
|
||||||
|
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
||||||
|
])
|
||||||
|
|
||||||
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, pdus, outlier=True
|
destination,
|
||||||
|
[p for p in pdus if p.event_id not in seen_events],
|
||||||
|
outlier=True
|
||||||
|
)
|
||||||
|
signed_pdus.extend(
|
||||||
|
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination,
|
||||||
|
[p for p in auth_chain if p.event_id not in seen_events],
|
||||||
|
outlier=True
|
||||||
|
)
|
||||||
|
signed_auth.extend(
|
||||||
|
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth.sort(key=lambda e: e.depth)
|
signed_auth.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
defer.returnValue((signed_pdus, signed_auth))
|
defer.returnValue((signed_pdus, signed_auth))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_events(self, destinations, room_id, event_ids, return_local=True):
|
||||||
|
"""Fetch events from some remote destinations, checking if we already
|
||||||
|
have them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destinations (list)
|
||||||
|
room_id (str)
|
||||||
|
event_ids (list)
|
||||||
|
return_local (bool): Whether to include events we already have in
|
||||||
|
the DB in the returned list of events
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: A deferred resolving to a 2-tuple where the first is a list of
|
||||||
|
events and the second is a list of event ids that we failed to fetch.
|
||||||
|
"""
|
||||||
|
if return_local:
|
||||||
|
seen_events = yield self.store.get_events(event_ids)
|
||||||
|
signed_events = seen_events.values()
|
||||||
|
else:
|
||||||
|
seen_events = yield self.store.have_events(event_ids)
|
||||||
|
signed_events = []
|
||||||
|
|
||||||
|
failed_to_fetch = set()
|
||||||
|
|
||||||
|
missing_events = set(event_ids)
|
||||||
|
for k in seen_events:
|
||||||
|
missing_events.discard(k)
|
||||||
|
|
||||||
|
if not missing_events:
|
||||||
|
defer.returnValue((signed_events, failed_to_fetch))
|
||||||
|
|
||||||
|
def random_server_list():
|
||||||
|
srvs = list(destinations)
|
||||||
|
random.shuffle(srvs)
|
||||||
|
return srvs
|
||||||
|
|
||||||
|
batch_size = 20
|
||||||
|
missing_events = list(missing_events)
|
||||||
|
for i in xrange(0, len(missing_events), batch_size):
|
||||||
|
batch = set(missing_events[i:i + batch_size])
|
||||||
|
|
||||||
|
deferreds = [
|
||||||
|
self.get_pdu(
|
||||||
|
destinations=random_server_list(),
|
||||||
|
event_id=e_id,
|
||||||
|
)
|
||||||
|
for e_id in batch
|
||||||
|
]
|
||||||
|
|
||||||
|
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
for success, result in res:
|
||||||
|
if success:
|
||||||
|
signed_events.append(result)
|
||||||
|
batch.discard(result.event_id)
|
||||||
|
|
||||||
|
# We removed all events we successfully fetched from `batch`
|
||||||
|
failed_to_fetch.update(batch)
|
||||||
|
|
||||||
|
defer.returnValue((signed_events, failed_to_fetch))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_event_auth(self, destination, room_id, event_id):
|
def get_event_auth(self, destination, room_id, event_id):
|
||||||
|
@ -414,14 +525,19 @@ class FederationClient(FederationBase):
|
||||||
(destination, self.event_from_pdu_json(pdu_dict))
|
(destination, self.event_from_pdu_json(pdu_dict))
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except CodeMessageException:
|
except CodeMessageException as e:
|
||||||
|
if not 500 <= e.code < 600:
|
||||||
raise
|
raise
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to make_%s via %s: %s",
|
||||||
|
membership, destination, e.message
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to make_%s via %s: %s",
|
"Failed to make_%s via %s: %s",
|
||||||
membership, destination, e.message
|
membership, destination, e.message
|
||||||
)
|
)
|
||||||
raise
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
raise RuntimeError("Failed to send to any server.")
|
||||||
|
|
||||||
|
@ -493,8 +609,14 @@ class FederationClient(FederationBase):
|
||||||
"auth_chain": signed_auth,
|
"auth_chain": signed_auth,
|
||||||
"origin": destination,
|
"origin": destination,
|
||||||
})
|
})
|
||||||
except CodeMessageException:
|
except CodeMessageException as e:
|
||||||
|
if not 500 <= e.code < 600:
|
||||||
raise
|
raise
|
||||||
|
else:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to send_join via %s: %s",
|
||||||
|
destination, e.message
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to send_join via %s: %s",
|
"Failed to send_join via %s: %s",
|
||||||
|
|
|
@ -21,10 +21,11 @@ from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from synapse.api.errors import FederationError, SynapseError
|
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
|
|
||||||
|
@ -48,7 +49,14 @@ class FederationServer(FederationBase):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(FederationServer, self).__init__(hs)
|
super(FederationServer, self).__init__(hs)
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
self._room_pdu_linearizer = Linearizer()
|
self._room_pdu_linearizer = Linearizer()
|
||||||
|
self._server_linearizer = Linearizer()
|
||||||
|
|
||||||
|
# We cache responses to state queries, as they take a while and often
|
||||||
|
# come in waves.
|
||||||
|
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
|
||||||
|
|
||||||
def set_handler(self, handler):
|
def set_handler(self, handler):
|
||||||
"""Sets the handler that the replication layer will use to communicate
|
"""Sets the handler that the replication layer will use to communicate
|
||||||
|
@ -89,11 +97,14 @@ class FederationServer(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
def on_backfill_request(self, origin, room_id, versions, limit):
|
||||||
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
pdus = yield self.handler.on_backfill_request(
|
pdus = yield self.handler.on_backfill_request(
|
||||||
origin, room_id, versions, limit
|
origin, room_id, versions, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
res = self._transaction_from_pdus(pdus).get_dict()
|
||||||
|
|
||||||
|
defer.returnValue((200, res))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -184,9 +195,50 @@ class FederationServer(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_context_state_request(self, origin, room_id, event_id):
|
def on_context_state_request(self, origin, room_id, event_id):
|
||||||
if event_id:
|
if not event_id:
|
||||||
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
|
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||||
|
if not in_room:
|
||||||
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
result = self._state_resp_cache.get((room_id, event_id))
|
||||||
|
if not result:
|
||||||
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
|
resp = yield self._state_resp_cache.set(
|
||||||
|
(room_id, event_id),
|
||||||
|
self._on_context_state_request_compute(room_id, event_id)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resp = yield result
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_state_ids_request(self, origin, room_id, event_id):
|
||||||
|
if not event_id:
|
||||||
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
|
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||||
|
if not in_room:
|
||||||
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
pdus = yield self.handler.get_state_for_pdu(
|
pdus = yield self.handler.get_state_for_pdu(
|
||||||
origin, room_id, event_id,
|
room_id, event_id,
|
||||||
|
)
|
||||||
|
auth_chain = yield self.store.get_auth_chain(
|
||||||
|
[pdu.event_id for pdu in pdus]
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"pdu_ids": [pdu.event_id for pdu in pdus],
|
||||||
|
"auth_chain_ids": [pdu.event_id for pdu in auth_chain],
|
||||||
|
}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _on_context_state_request_compute(self, room_id, event_id):
|
||||||
|
pdus = yield self.handler.get_state_for_pdu(
|
||||||
|
room_id, event_id,
|
||||||
)
|
)
|
||||||
auth_chain = yield self.store.get_auth_chain(
|
auth_chain = yield self.store.get_auth_chain(
|
||||||
[pdu.event_id for pdu in pdus]
|
[pdu.event_id for pdu in pdus]
|
||||||
|
@ -203,13 +255,11 @@ class FederationServer(FederationBase):
|
||||||
self.hs.config.signing_key[0]
|
self.hs.config.signing_key[0]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise NotImplementedError("Specify an event")
|
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue({
|
||||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||||
}))
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -283,14 +333,16 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_event_auth(self, origin, room_id, event_id):
|
def on_event_auth(self, origin, room_id, event_id):
|
||||||
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||||
defer.returnValue((200, {
|
res = {
|
||||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||||
}))
|
}
|
||||||
|
defer.returnValue((200, res))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_query_auth_request(self, origin, content, event_id):
|
def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||||
"""
|
"""
|
||||||
Content is a dict with keys::
|
Content is a dict with keys::
|
||||||
auth_chain (list): A list of events that give the auth chain.
|
auth_chain (list): A list of events that give the auth chain.
|
||||||
|
@ -309,6 +361,7 @@ class FederationServer(FederationBase):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in `dict` with the same format as `content`
|
Deferred: Results in `dict` with the same format as `content`
|
||||||
"""
|
"""
|
||||||
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(e)
|
self.event_from_pdu_json(e)
|
||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
|
@ -340,27 +393,9 @@ class FederationServer(FederationBase):
|
||||||
(200, send_content)
|
(200, send_content)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
@log_function
|
@log_function
|
||||||
def on_query_client_keys(self, origin, content):
|
def on_query_client_keys(self, origin, content):
|
||||||
query = []
|
return self.on_query_request("client_keys", content)
|
||||||
for user_id, device_ids in content.get("device_keys", {}).items():
|
|
||||||
if not device_ids:
|
|
||||||
query.append((user_id, None))
|
|
||||||
else:
|
|
||||||
for device_id in device_ids:
|
|
||||||
query.append((user_id, device_id))
|
|
||||||
|
|
||||||
results = yield self.store.get_e2e_device_keys(query)
|
|
||||||
|
|
||||||
json_result = {}
|
|
||||||
for user_id, device_keys in results.items():
|
|
||||||
for device_id, json_bytes in device_keys.items():
|
|
||||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
|
||||||
json_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({"device_keys": json_result})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -386,6 +421,7 @@ class FederationServer(FederationBase):
|
||||||
@log_function
|
@log_function
|
||||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||||
latest_events, limit, min_depth):
|
latest_events, limit, min_depth):
|
||||||
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
logger.info(
|
logger.info(
|
||||||
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
||||||
" limit: %d, min_depth: %d",
|
" limit: %d, min_depth: %d",
|
||||||
|
@ -396,7 +432,9 @@ class FederationServer(FederationBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(missing_events) < 5:
|
if len(missing_events) < 5:
|
||||||
logger.info("Returning %d events: %r", len(missing_events), missing_events)
|
logger.info(
|
||||||
|
"Returning %d events: %r", len(missing_events), missing_events
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Returning %d events", len(missing_events))
|
logger.info("Returning %d events", len(missing_events))
|
||||||
|
|
||||||
|
@ -567,7 +605,7 @@ class FederationServer(FederationBase):
|
||||||
origin, pdu.room_id, pdu.event_id,
|
origin, pdu.room_id, pdu.event_id,
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.warn("Failed to get state for event: %s", pdu.event_id)
|
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||||
|
|
||||||
yield self.handler.on_receive_pdu(
|
yield self.handler.on_receive_pdu(
|
||||||
origin,
|
origin,
|
||||||
|
|
|
@ -54,6 +54,28 @@ class TransportLayerClient(object):
|
||||||
destination, path=path, args={"event_id": event_id},
|
destination, path=path, args={"event_id": event_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_room_state_ids(self, destination, room_id, event_id):
|
||||||
|
""" Requests all state for a given room from the given server at the
|
||||||
|
given event. Returns the state's event_id's
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The host name of the remote home server we want
|
||||||
|
to get the state from.
|
||||||
|
context (str): The name of the context we want the state of
|
||||||
|
event_id (str): The event we want the context at.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Results in a dict received from the remote homeserver.
|
||||||
|
"""
|
||||||
|
logger.debug("get_room_state_ids dest=%s, room=%s",
|
||||||
|
destination, room_id)
|
||||||
|
|
||||||
|
path = PREFIX + "/state_ids/%s/" % room_id
|
||||||
|
return self.client.get_json(
|
||||||
|
destination, path=path, args={"event_id": event_id},
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_event(self, destination, event_id, timeout=None):
|
def get_event(self, destination, event_id, timeout=None):
|
||||||
""" Requests the pdu with give id and origin from the given server.
|
""" Requests the pdu with give id and origin from the given server.
|
||||||
|
|
|
@ -18,13 +18,14 @@ from twisted.internet import defer
|
||||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import parse_json_object_from_request, parse_string
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import simplejson as json
|
|
||||||
import re
|
import re
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(SynapseError):
|
||||||
|
"""There was a problem authenticating the request"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoAuthenticationError(AuthenticationError):
|
||||||
|
"""The request had no authentication information"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Authenticator(object):
|
class Authenticator(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
@ -67,7 +78,7 @@ class Authenticator(object):
|
||||||
|
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def authenticate_request(self, request):
|
def authenticate_request(self, request, content):
|
||||||
json_request = {
|
json_request = {
|
||||||
"method": request.method,
|
"method": request.method,
|
||||||
"uri": request.uri,
|
"uri": request.uri,
|
||||||
|
@ -75,17 +86,10 @@ class Authenticator(object):
|
||||||
"signatures": {},
|
"signatures": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
content = None
|
if content is not None:
|
||||||
origin = None
|
|
||||||
|
|
||||||
if request.method in ["PUT", "POST"]:
|
|
||||||
# TODO: Handle other method types? other content types?
|
|
||||||
try:
|
|
||||||
content_bytes = request.content.read()
|
|
||||||
content = json.loads(content_bytes)
|
|
||||||
json_request["content"] = content
|
json_request["content"] = content
|
||||||
except:
|
|
||||||
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
|
origin = None
|
||||||
|
|
||||||
def parse_auth_header(header_str):
|
def parse_auth_header(header_str):
|
||||||
try:
|
try:
|
||||||
|
@ -103,14 +107,14 @@ class Authenticator(object):
|
||||||
sig = strip_quotes(param_dict["sig"])
|
sig = strip_quotes(param_dict["sig"])
|
||||||
return (origin, key, sig)
|
return (origin, key, sig)
|
||||||
except:
|
except:
|
||||||
raise SynapseError(
|
raise AuthenticationError(
|
||||||
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
|
|
||||||
if not auth_headers:
|
if not auth_headers:
|
||||||
raise SynapseError(
|
raise NoAuthenticationError(
|
||||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -121,7 +125,7 @@ class Authenticator(object):
|
||||||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||||
|
|
||||||
if not json_request["signatures"]:
|
if not json_request["signatures"]:
|
||||||
raise SynapseError(
|
raise NoAuthenticationError(
|
||||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -130,10 +134,12 @@ class Authenticator(object):
|
||||||
logger.info("Request from %s", origin)
|
logger.info("Request from %s", origin)
|
||||||
request.authenticated_entity = origin
|
request.authenticated_entity = origin
|
||||||
|
|
||||||
defer.returnValue((origin, content))
|
defer.returnValue(origin)
|
||||||
|
|
||||||
|
|
||||||
class BaseFederationServlet(object):
|
class BaseFederationServlet(object):
|
||||||
|
REQUIRE_AUTH = True
|
||||||
|
|
||||||
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
||||||
room_list_handler):
|
room_list_handler):
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
@ -141,29 +147,46 @@ class BaseFederationServlet(object):
|
||||||
self.ratelimiter = ratelimiter
|
self.ratelimiter = ratelimiter
|
||||||
self.room_list_handler = room_list_handler
|
self.room_list_handler = room_list_handler
|
||||||
|
|
||||||
def _wrap(self, code):
|
def _wrap(self, func):
|
||||||
authenticator = self.authenticator
|
authenticator = self.authenticator
|
||||||
ratelimiter = self.ratelimiter
|
ratelimiter = self.ratelimiter
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@functools.wraps(code)
|
@functools.wraps(func)
|
||||||
def new_code(request, *args, **kwargs):
|
def new_func(request, *args, **kwargs):
|
||||||
|
content = None
|
||||||
|
if request.method in ["PUT", "POST"]:
|
||||||
|
# TODO: Handle other method types? other content types?
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(origin, content) = yield authenticator.authenticate_request(request)
|
origin = yield authenticator.authenticate_request(request, content)
|
||||||
with ratelimiter.ratelimit(origin) as d:
|
except NoAuthenticationError:
|
||||||
yield d
|
origin = None
|
||||||
response = yield code(
|
if self.REQUIRE_AUTH:
|
||||||
origin, content, request.args, *args, **kwargs
|
logger.exception("authenticate_request failed")
|
||||||
)
|
raise
|
||||||
except:
|
except:
|
||||||
logger.exception("authenticate_request failed")
|
logger.exception("authenticate_request failed")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if origin:
|
||||||
|
with ratelimiter.ratelimit(origin) as d:
|
||||||
|
yield d
|
||||||
|
response = yield func(
|
||||||
|
origin, content, request.args, *args, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = yield func(
|
||||||
|
origin, content, request.args, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
||||||
# Extra logic that functools.wraps() doesn't finish
|
# Extra logic that functools.wraps() doesn't finish
|
||||||
new_code.__self__ = code.__self__
|
new_func.__self__ = func.__self__
|
||||||
|
|
||||||
return new_code
|
return new_func
|
||||||
|
|
||||||
def register(self, server):
|
def register(self, server):
|
||||||
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
||||||
|
@ -271,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FederationStateIdsServlet(BaseFederationServlet):
|
||||||
|
PATH = "/state_ids/(?P<room_id>[^/]*)/"
|
||||||
|
|
||||||
|
def on_GET(self, origin, content, query, room_id):
|
||||||
|
return self.handler.on_state_ids_request(
|
||||||
|
origin,
|
||||||
|
room_id,
|
||||||
|
query.get("event_id", [None])[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FederationBackfillServlet(BaseFederationServlet):
|
class FederationBackfillServlet(BaseFederationServlet):
|
||||||
PATH = "/backfill/(?P<context>[^/]*)/"
|
PATH = "/backfill/(?P<context>[^/]*)/"
|
||||||
|
|
||||||
|
@ -367,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
||||||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||||
PATH = "/user/keys/query"
|
PATH = "/user/keys/query"
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_POST(self, origin, content, query):
|
def on_POST(self, origin, content, query):
|
||||||
response = yield self.handler.on_query_client_keys(origin, content)
|
return self.handler.on_query_client_keys(origin, content)
|
||||||
defer.returnValue((200, response))
|
|
||||||
|
|
||||||
|
|
||||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||||
|
@ -388,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, origin, content, query, context, event_id):
|
def on_POST(self, origin, content, query, context, event_id):
|
||||||
new_content = yield self.handler.on_query_auth_request(
|
new_content = yield self.handler.on_query_auth_request(
|
||||||
origin, content, event_id
|
origin, content, context, event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, new_content))
|
defer.returnValue((200, new_content))
|
||||||
|
@ -420,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
|
||||||
class On3pidBindServlet(BaseFederationServlet):
|
class On3pidBindServlet(BaseFederationServlet):
|
||||||
PATH = "/3pid/onbind"
|
PATH = "/3pid/onbind"
|
||||||
|
|
||||||
|
REQUIRE_AUTH = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, origin, content, query):
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
if "invites" in content:
|
if "invites" in content:
|
||||||
last_exception = None
|
last_exception = None
|
||||||
for invite in content["invites"]:
|
for invite in content["invites"]:
|
||||||
|
@ -444,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
|
||||||
raise last_exception
|
raise last_exception
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
# Avoid doing remote HS authorization checks which are done by default by
|
|
||||||
# BaseFederationServlet.
|
|
||||||
def _wrap(self, code):
|
|
||||||
return code
|
|
||||||
|
|
||||||
|
|
||||||
class OpenIdUserInfo(BaseFederationServlet):
|
class OpenIdUserInfo(BaseFederationServlet):
|
||||||
"""
|
"""
|
||||||
|
@ -469,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
|
||||||
|
|
||||||
PATH = "/openid/userinfo"
|
PATH = "/openid/userinfo"
|
||||||
|
|
||||||
|
REQUIRE_AUTH = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, origin, content, query):
|
||||||
token = parse_string(request, "access_token")
|
token = query.get("access_token", [None])[0]
|
||||||
if token is None:
|
if token is None:
|
||||||
defer.returnValue((401, {
|
defer.returnValue((401, {
|
||||||
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
||||||
|
@ -488,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
|
||||||
|
|
||||||
defer.returnValue((200, {"sub": user_id}))
|
defer.returnValue((200, {"sub": user_id}))
|
||||||
|
|
||||||
# Avoid doing remote HS authorization checks which are done by default by
|
|
||||||
# BaseFederationServlet.
|
|
||||||
def _wrap(self, code):
|
|
||||||
return code
|
|
||||||
|
|
||||||
|
|
||||||
class PublicRoomList(BaseFederationServlet):
|
class PublicRoomList(BaseFederationServlet):
|
||||||
"""
|
"""
|
||||||
|
@ -533,11 +558,26 @@ class PublicRoomList(BaseFederationServlet):
|
||||||
defer.returnValue((200, data))
|
defer.returnValue((200, data))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationVersionServlet(BaseFederationServlet):
|
||||||
|
PATH = "/version"
|
||||||
|
|
||||||
|
REQUIRE_AUTH = False
|
||||||
|
|
||||||
|
def on_GET(self, origin, content, query):
|
||||||
|
return defer.succeed((200, {
|
||||||
|
"server": {
|
||||||
|
"name": "Synapse",
|
||||||
|
"version": get_version_string(synapse)
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
SERVLET_CLASSES = (
|
SERVLET_CLASSES = (
|
||||||
FederationSendServlet,
|
FederationSendServlet,
|
||||||
FederationPullServlet,
|
FederationPullServlet,
|
||||||
FederationEventServlet,
|
FederationEventServlet,
|
||||||
FederationStateServlet,
|
FederationStateServlet,
|
||||||
|
FederationStateIdsServlet,
|
||||||
FederationBackfillServlet,
|
FederationBackfillServlet,
|
||||||
FederationQueryServlet,
|
FederationQueryServlet,
|
||||||
FederationMakeJoinServlet,
|
FederationMakeJoinServlet,
|
||||||
|
@ -555,6 +595,7 @@ SERVLET_CLASSES = (
|
||||||
On3pidBindServlet,
|
On3pidBindServlet,
|
||||||
OpenIdUserInfo,
|
OpenIdUserInfo,
|
||||||
PublicRoomList,
|
PublicRoomList,
|
||||||
|
FederationVersionServlet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,10 +31,21 @@ from .search import SearchHandler
|
||||||
|
|
||||||
class Handlers(object):
|
class Handlers(object):
|
||||||
|
|
||||||
""" A collection of all the event handlers.
|
""" Deprecated. A collection of handlers.
|
||||||
|
|
||||||
There's no need to lazily create these; we'll just make them all eagerly
|
At some point most of the classes whose name ended "Handler" were
|
||||||
at construction time.
|
accessed through this class.
|
||||||
|
|
||||||
|
However this makes it painful to unit test the handlers and to run cut
|
||||||
|
down versions of synapse that only use specific handlers because using a
|
||||||
|
single handler required creating all of the handlers. So some of the
|
||||||
|
handlers have been lifted out of the Handlers object and are now accessed
|
||||||
|
directly through the homeserver object itself.
|
||||||
|
|
||||||
|
Any new handlers should follow the new pattern of being accessed through
|
||||||
|
the homeserver object and should not be added to the Handlers object.
|
||||||
|
|
||||||
|
The remaining handlers should be moved out of the handlers object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
import synapse.types
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import UserID, Requester
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -31,11 +31,15 @@ class BaseHandler(object):
|
||||||
Common base class for the event handlers.
|
Common base class for the event handlers.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
store (synapse.storage.events.StateStore):
|
store (synapse.storage.DataStore):
|
||||||
state_handler (synapse.state.StateHandler):
|
state_handler (synapse.state.StateHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer):
|
||||||
|
"""
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
@ -120,7 +124,8 @@ class BaseHandler(object):
|
||||||
# and having homeservers have their own users leave keeps more
|
# and having homeservers have their own users leave keeps more
|
||||||
# of that decision-making and control local to the guest-having
|
# of that decision-making and control local to the guest-having
|
||||||
# homeserver.
|
# homeserver.
|
||||||
requester = Requester(target_user, "", True)
|
requester = synapse.types.create_requester(
|
||||||
|
target_user, is_guest=True)
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_handlers().room_member_handler
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
|
|
|
@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.config.ldap import LDAPMode
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
|
@ -28,6 +29,12 @@ import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
import simplejson
|
import simplejson
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ldap3
|
||||||
|
except ImportError:
|
||||||
|
ldap3 = None
|
||||||
|
pass
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
|
||||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer):
|
||||||
|
"""
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
self.checkers = {
|
self.checkers = {
|
||||||
LoginType.PASSWORD: self._check_password_auth,
|
LoginType.PASSWORD: self._check_password_auth,
|
||||||
|
@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
|
||||||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||||
|
|
||||||
self.ldap_enabled = hs.config.ldap_enabled
|
self.ldap_enabled = hs.config.ldap_enabled
|
||||||
self.ldap_server = hs.config.ldap_server
|
if self.ldap_enabled:
|
||||||
self.ldap_port = hs.config.ldap_port
|
if not ldap3:
|
||||||
self.ldap_tls = hs.config.ldap_tls
|
raise RuntimeError(
|
||||||
self.ldap_search_base = hs.config.ldap_search_base
|
'Missing ldap3 library. This is required for LDAP Authentication.'
|
||||||
self.ldap_search_property = hs.config.ldap_search_property
|
)
|
||||||
self.ldap_email_property = hs.config.ldap_email_property
|
self.ldap_mode = hs.config.ldap_mode
|
||||||
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
self.ldap_uri = hs.config.ldap_uri
|
||||||
|
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||||
if self.ldap_enabled is True:
|
self.ldap_base = hs.config.ldap_base
|
||||||
import ldap
|
self.ldap_filter = hs.config.ldap_filter
|
||||||
logger.info("Import ldap version: %s", ldap.__version__)
|
self.ldap_attributes = hs.config.ldap_attributes
|
||||||
|
if self.ldap_mode == LDAPMode.SEARCH:
|
||||||
|
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||||
|
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
|
||||||
sess = self._get_session_info(session_id)
|
sess = self._get_session_info(session_id)
|
||||||
return sess.setdefault('serverdict', {}).get(key, default)
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_password_auth(self, authdict, _):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||||
|
@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
if not (yield self._check_password(user_id, password)):
|
return self._check_password(user_id, password)
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
defer.returnValue(user_id)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
|
@ -270,7 +280,16 @@ class AuthHandler(BaseHandler):
|
||||||
data = pde.response
|
data = pde.response
|
||||||
resp_body = simplejson.loads(data)
|
resp_body = simplejson.loads(data)
|
||||||
|
|
||||||
if 'success' in resp_body and resp_body['success']:
|
if 'success' in resp_body:
|
||||||
|
# Note that we do NOT check the hostname here: we explicitly
|
||||||
|
# intend the CAPTCHA to be presented by whatever client the
|
||||||
|
# user is using, we just care that they have completed a CAPTCHA.
|
||||||
|
logger.info(
|
||||||
|
"%s reCAPTCHA from hostname %s",
|
||||||
|
"Successful" if resp_body['success'] else "Failed",
|
||||||
|
resp_body.get('hostname')
|
||||||
|
)
|
||||||
|
if resp_body['success']:
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def validate_password_login(self, user_id, password):
|
||||||
def login_with_password(self, user_id, password):
|
|
||||||
"""
|
"""
|
||||||
Authenticates the user with their username and password.
|
Authenticates the user with their username and password.
|
||||||
|
|
||||||
Used only by the v1 login API.
|
Used only by the v1 login API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User ID
|
user_id (str): complete @user:id
|
||||||
password (str): Password
|
password (str): Password
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
defer.Deferred: (str) canonical user id
|
||||||
The user's ID.
|
|
||||||
The access token for the user's session.
|
|
||||||
The refresh token for the user's session.
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem accessing the database
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
|
return self._check_password(user_id, password)
|
||||||
if not (yield self._check_password(user_id, password)):
|
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
|
||||||
access_token = yield self.issue_access_token(user_id)
|
|
||||||
refresh_token = yield self.issue_refresh_token(user_id)
|
|
||||||
defer.returnValue((user_id, access_token, refresh_token))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_login_tuple_for_user_id(self, user_id):
|
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||||
|
initial_display_name=None):
|
||||||
"""
|
"""
|
||||||
Gets login tuple for the user with the given user ID.
|
Gets login tuple for the user with the given user ID.
|
||||||
|
|
||||||
|
Creates a new access/refresh token for the user.
|
||||||
|
|
||||||
The user is assumed to have been authenticated by some other
|
The user is assumed to have been authenticated by some other
|
||||||
machanism (e.g. CAS)
|
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||||
|
|
||||||
|
The device will be recorded in the table if it is not there already.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User ID
|
user_id (str): canonical User ID
|
||||||
|
device_id (str|None): the device ID to associate with the tokens.
|
||||||
|
None to leave the tokens unassociated with a device (deprecated:
|
||||||
|
we should always have a device ID)
|
||||||
|
initial_display_name (str): display name to associate with the
|
||||||
|
device if it needs re-registering
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
The user's ID.
|
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
The refresh token for the user's session.
|
The refresh token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
|
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
# the device *should* have been registered before we got here; however,
|
||||||
access_token = yield self.issue_access_token(user_id)
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
refresh_token = yield self.issue_refresh_token(user_id)
|
# really don't want is active access_tokens without a record of the
|
||||||
defer.returnValue((user_id, access_token, refresh_token))
|
# device, so we double-check it here.
|
||||||
|
if device_id is not None:
|
||||||
|
yield self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def does_user_exist(self, user_id):
|
def check_user_exists(self, user_id):
|
||||||
|
"""
|
||||||
|
Checks to see if a user with the given id exists. Will check case
|
||||||
|
insensitively, but return None if there are multiple inexact matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(str) user_id: complete @user:id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (str) canonical_user_id, or None if zero or
|
||||||
|
multiple matches
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
yield self._find_user_id_and_pwd_hash(user_id)
|
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
defer.returnValue(True)
|
defer.returnValue(res[0])
|
||||||
except LoginError:
|
except LoginError:
|
||||||
defer.returnValue(False)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _find_user_id_and_pwd_hash(self, user_id):
|
def _find_user_id_and_pwd_hash(self, user_id):
|
||||||
|
@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password(self, user_id, password):
|
def _check_password(self, user_id, password):
|
||||||
"""
|
"""Authenticate a user against the LDAP and local databases.
|
||||||
|
|
||||||
|
user_id is checked case insensitively against the local database, but
|
||||||
|
will throw if there are multiple inexact matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): complete @user:id
|
||||||
Returns:
|
Returns:
|
||||||
True if the user_id successfully authenticated
|
(str) the canonical_user_id
|
||||||
|
Raises:
|
||||||
|
LoginError if the password was incorrect
|
||||||
"""
|
"""
|
||||||
valid_ldap = yield self._check_ldap_password(user_id, password)
|
valid_ldap = yield self._check_ldap_password(user_id, password)
|
||||||
if valid_ldap:
|
if valid_ldap:
|
||||||
defer.returnValue(True)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
valid_local_password = yield self._check_local_password(user_id, password)
|
result = yield self._check_local_password(user_id, password)
|
||||||
if valid_local_password:
|
defer.returnValue(result)
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_local_password(self, user_id, password):
|
def _check_local_password(self, user_id, password):
|
||||||
try:
|
"""Authenticate a user against the local password database.
|
||||||
|
|
||||||
|
user_id is checked case insensitively, but will throw if there are
|
||||||
|
multiple inexact matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): complete @user:id
|
||||||
|
Returns:
|
||||||
|
(str) the canonical_user_id
|
||||||
|
Raises:
|
||||||
|
LoginError if the password was incorrect
|
||||||
|
"""
|
||||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
defer.returnValue(self.validate_hash(password, password_hash))
|
result = self.validate_hash(password, password_hash)
|
||||||
except LoginError:
|
if not result:
|
||||||
defer.returnValue(False)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_ldap_password(self, user_id, password):
|
def _check_ldap_password(self, user_id, password):
|
||||||
if not self.ldap_enabled:
|
""" Attempt to authenticate a user against an LDAP Server
|
||||||
logger.debug("LDAP not configured")
|
and register an account if none exists.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if authentication against LDAP was successful
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not ldap3 or not self.ldap_enabled:
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
import ldap
|
if self.ldap_mode not in LDAPMode.LIST:
|
||||||
|
raise RuntimeError(
|
||||||
logger.info("Authenticating %s with LDAP" % user_id)
|
'Invalid ldap mode specified: {mode}'.format(
|
||||||
try:
|
mode=self.ldap_mode
|
||||||
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
|
)
|
||||||
logger.debug("Connecting LDAP server at %s" % ldap_url)
|
|
||||||
l = ldap.initialize(ldap_url)
|
|
||||||
if self.ldap_tls:
|
|
||||||
logger.debug("Initiating TLS")
|
|
||||||
self._connection.start_tls_s()
|
|
||||||
|
|
||||||
local_name = UserID.from_string(user_id).localpart
|
|
||||||
|
|
||||||
dn = "%s=%s, %s" % (
|
|
||||||
self.ldap_search_property,
|
|
||||||
local_name,
|
|
||||||
self.ldap_search_base)
|
|
||||||
logger.debug("DN for LDAP authentication: %s" % dn)
|
|
||||||
|
|
||||||
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
|
|
||||||
|
|
||||||
if not (yield self.does_user_exist(user_id)):
|
|
||||||
handler = self.hs.get_handlers().registration_handler
|
|
||||||
user_id, access_token = (
|
|
||||||
yield handler.register(localpart=local_name)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
server = ldap3.Server(self.ldap_uri)
|
||||||
|
logger.debug(
|
||||||
|
"Attempting ldap connection with %s",
|
||||||
|
self.ldap_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
localpart = UserID.from_string(user_id).localpart
|
||||||
|
if self.ldap_mode == LDAPMode.SIMPLE:
|
||||||
|
# bind with the the local users ldap credentials
|
||||||
|
bind_dn = "{prop}={value},{base}".format(
|
||||||
|
prop=self.ldap_attributes['uid'],
|
||||||
|
value=localpart,
|
||||||
|
base=self.ldap_base
|
||||||
|
)
|
||||||
|
conn = ldap3.Connection(server, bind_dn, password)
|
||||||
|
logger.debug(
|
||||||
|
"Established ldap connection in simple mode: %s",
|
||||||
|
conn
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ldap_start_tls:
|
||||||
|
conn.start_tls()
|
||||||
|
logger.debug(
|
||||||
|
"Upgraded ldap connection in simple mode through StartTLS: %s",
|
||||||
|
conn
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.bind()
|
||||||
|
|
||||||
|
elif self.ldap_mode == LDAPMode.SEARCH:
|
||||||
|
# connect with preconfigured credentials and search for local user
|
||||||
|
conn = ldap3.Connection(
|
||||||
|
server,
|
||||||
|
self.ldap_bind_dn,
|
||||||
|
self.ldap_bind_password
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Established ldap connection in search mode: %s",
|
||||||
|
conn
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ldap_start_tls:
|
||||||
|
conn.start_tls()
|
||||||
|
logger.debug(
|
||||||
|
"Upgraded ldap connection in search mode through StartTLS: %s",
|
||||||
|
conn
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.bind()
|
||||||
|
|
||||||
|
# find matching dn
|
||||||
|
query = "({prop}={value})".format(
|
||||||
|
prop=self.ldap_attributes['uid'],
|
||||||
|
value=localpart
|
||||||
|
)
|
||||||
|
if self.ldap_filter:
|
||||||
|
query = "(&{query}{filter})".format(
|
||||||
|
query=query,
|
||||||
|
filter=self.ldap_filter
|
||||||
|
)
|
||||||
|
logger.debug("ldap search filter: %s", query)
|
||||||
|
result = conn.search(self.ldap_base, query)
|
||||||
|
|
||||||
|
if result and len(conn.response) == 1:
|
||||||
|
# found exactly one result
|
||||||
|
user_dn = conn.response[0]['dn']
|
||||||
|
logger.debug('ldap search found dn: %s', user_dn)
|
||||||
|
|
||||||
|
# unbind and reconnect, rebind with found dn
|
||||||
|
conn.unbind()
|
||||||
|
conn = ldap3.Connection(
|
||||||
|
server,
|
||||||
|
user_dn,
|
||||||
|
password,
|
||||||
|
auto_bind=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# found 0 or > 1 results, abort!
|
||||||
|
logger.warn(
|
||||||
|
"ldap search returned unexpected (%d!=1) amount of results",
|
||||||
|
len(conn.response)
|
||||||
|
)
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"User authenticated against ldap server: %s",
|
||||||
|
conn
|
||||||
|
)
|
||||||
|
|
||||||
|
# check for existing account, if none exists, create one
|
||||||
|
if not (yield self.check_user_exists(user_id)):
|
||||||
|
# query user metadata for account creation
|
||||||
|
query = "({prop}={value})".format(
|
||||||
|
prop=self.ldap_attributes['uid'],
|
||||||
|
value=localpart
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
||||||
|
query = "(&{filter}{user_filter})".format(
|
||||||
|
filter=query,
|
||||||
|
user_filter=self.ldap_filter
|
||||||
|
)
|
||||||
|
logger.debug("ldap registration filter: %s", query)
|
||||||
|
|
||||||
|
result = conn.search(
|
||||||
|
search_base=self.ldap_base,
|
||||||
|
search_filter=query,
|
||||||
|
attributes=[
|
||||||
|
self.ldap_attributes['name'],
|
||||||
|
self.ldap_attributes['mail']
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(conn.response) == 1:
|
||||||
|
attrs = conn.response[0]['attributes']
|
||||||
|
mail = attrs[self.ldap_attributes['mail']][0]
|
||||||
|
name = attrs[self.ldap_attributes['name']][0]
|
||||||
|
|
||||||
|
# create account
|
||||||
|
registration_handler = self.hs.get_handlers().registration_handler
|
||||||
|
user_id, access_token = (
|
||||||
|
yield registration_handler.register(localpart=localpart)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: bind email, set displayname with data from ldap directory
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"ldap registration successful: %d: %s (%s, %)",
|
||||||
|
user_id,
|
||||||
|
localpart,
|
||||||
|
name,
|
||||||
|
mail
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||||
|
len(result)
|
||||||
|
)
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
except ldap.LDAPError, e:
|
except ldap3.core.exceptions.LDAPException as e:
|
||||||
logger.warn("LDAP error: %s", e)
|
logger.warn("Error during ldap authentication: %s", e)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def issue_access_token(self, user_id):
|
def issue_access_token(self, user_id, device_id=None):
|
||||||
access_token = self.generate_access_token(user_id)
|
access_token = self.generate_access_token(user_id)
|
||||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
yield self.store.add_access_token_to_user(user_id, access_token,
|
||||||
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def issue_refresh_token(self, user_id):
|
def issue_refresh_token(self, user_id, device_id=None):
|
||||||
refresh_token = self.generate_refresh_token(user_id)
|
refresh_token = self.generate_refresh_token(user_id)
|
||||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
|
||||||
|
device_id)
|
||||||
defer.returnValue(refresh_token)
|
defer.returnValue(refresh_token)
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None):
|
def generate_access_token(self, user_id, extra_caveats=None,
|
||||||
|
duration_in_ms=(60 * 60 * 1000)):
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + (60 * 60 * 1000)
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
for caveat in extra_caveats:
|
for caveat in extra_caveats:
|
||||||
macaroon.add_first_party_caveat(caveat)
|
macaroon.add_first_party_caveat(caveat)
|
||||||
|
@ -613,7 +797,8 @@ class AuthHandler(BaseHandler):
|
||||||
Returns:
|
Returns:
|
||||||
Hashed password (str).
|
Hashed password (str).
|
||||||
"""
|
"""
|
||||||
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
|
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds))
|
||||||
|
|
||||||
def validate_hash(self, password, stored_hash):
|
def validate_hash(self, password, stored_hash):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
@ -626,6 +811,7 @@ class AuthHandler(BaseHandler):
|
||||||
Whether self.hash(password) == stored_hash (bool).
|
Whether self.hash(password) == stored_hash (bool).
|
||||||
"""
|
"""
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash
|
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||||
|
stored_hash.encode('utf-8')) == stored_hash
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
181
synapse/handlers/device.py
Normal file
181
synapse/handlers/device.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.api import errors
|
||||||
|
from synapse.util import stringutils
|
||||||
|
from twisted.internet import defer
|
||||||
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceHandler(BaseHandler):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(DeviceHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_device_registered(self, user_id, device_id,
|
||||||
|
initial_device_display_name=None):
|
||||||
|
"""
|
||||||
|
If the given device has not been registered, register it with the
|
||||||
|
supplied display name.
|
||||||
|
|
||||||
|
If no device_id is supplied, we make one up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): @user:id
|
||||||
|
device_id (str | None): device id supplied by client
|
||||||
|
initial_device_display_name (str | None): device display name from
|
||||||
|
client
|
||||||
|
Returns:
|
||||||
|
str: device id (generated if none was supplied)
|
||||||
|
"""
|
||||||
|
if device_id is not None:
|
||||||
|
yield self.store.store_device(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
initial_device_display_name=initial_device_display_name,
|
||||||
|
ignore_if_known=True,
|
||||||
|
)
|
||||||
|
defer.returnValue(device_id)
|
||||||
|
|
||||||
|
# if the device id is not specified, we'll autogen one, but loop a few
|
||||||
|
# times in case of a clash.
|
||||||
|
attempts = 0
|
||||||
|
while attempts < 5:
|
||||||
|
try:
|
||||||
|
device_id = stringutils.random_string_with_symbols(16)
|
||||||
|
yield self.store.store_device(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
initial_device_display_name=initial_device_display_name,
|
||||||
|
ignore_if_known=False,
|
||||||
|
)
|
||||||
|
defer.returnValue(device_id)
|
||||||
|
except errors.StoreError:
|
||||||
|
attempts += 1
|
||||||
|
|
||||||
|
raise errors.StoreError(500, "Couldn't generate a device ID.")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_devices_by_user(self, user_id):
|
||||||
|
"""
|
||||||
|
Retrieve the given user's devices
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: list[dict[str, X]]: info on each device
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_map = yield self.store.get_devices_by_user(user_id)
|
||||||
|
|
||||||
|
ips = yield self.store.get_last_client_ip_by_device(
|
||||||
|
devices=((user_id, device_id) for device_id in device_map.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
devices = device_map.values()
|
||||||
|
for device in devices:
|
||||||
|
_update_device_from_client_ips(device, ips)
|
||||||
|
|
||||||
|
defer.returnValue(devices)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_device(self, user_id, device_id):
|
||||||
|
""" Retrieve the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: dict[str, X]: info on the device
|
||||||
|
Raises:
|
||||||
|
errors.NotFoundError: if the device was not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
device = yield self.store.get_device(user_id, device_id)
|
||||||
|
except errors.StoreError:
|
||||||
|
raise errors.NotFoundError
|
||||||
|
ips = yield self.store.get_last_client_ip_by_device(
|
||||||
|
devices=((user_id, device_id),)
|
||||||
|
)
|
||||||
|
_update_device_from_client_ips(device, ips)
|
||||||
|
defer.returnValue(device)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
""" Delete the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.delete_device(user_id, device_id)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
# no match
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield self.store.user_delete_access_tokens(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
delete_refresh_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.delete_e2e_keys_by_device(
|
||||||
|
user_id=user_id, device_id=device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_device(self, user_id, device_id, content):
|
||||||
|
""" Update the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
content (dict): body of update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.update_device(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
new_display_name=content.get("display_name")
|
||||||
|
)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
raise errors.NotFoundError()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _update_device_from_client_ips(device, client_ips):
|
||||||
|
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||||
|
device.update({
|
||||||
|
"last_seen_ts": ip.get("last_seen"),
|
||||||
|
"last_seen_ip": ip.get("ip"),
|
||||||
|
})
|
139
synapse/handlers/e2e_keys.py
Normal file
139
synapse/handlers/e2e_keys.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api import errors
|
||||||
|
import synapse.types
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class E2eKeysHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
|
# doesn't really work as part of the generic query API, because the
|
||||||
|
# query request requires an object POST, but we abuse the
|
||||||
|
# "query handler" interface.
|
||||||
|
self.federation.register_query_handler(
|
||||||
|
"client_keys", self.on_federation_query_client_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_devices(self, query_body):
|
||||||
|
""" Handle a device key query from a client
|
||||||
|
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
->
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
|
|
||||||
|
# separate users by domain.
|
||||||
|
# make a map from domain to user_id to device_ids
|
||||||
|
queries_by_domain = collections.defaultdict(dict)
|
||||||
|
for user_id, device_ids in device_keys_query.items():
|
||||||
|
user = synapse.types.UserID.from_string(user_id)
|
||||||
|
queries_by_domain[user.domain][user_id] = device_ids
|
||||||
|
|
||||||
|
# do the queries
|
||||||
|
# TODO: do these in parallel
|
||||||
|
results = {}
|
||||||
|
for destination, destination_query in queries_by_domain.items():
|
||||||
|
if destination == self.server_name:
|
||||||
|
res = yield self.query_local_devices(destination_query)
|
||||||
|
else:
|
||||||
|
res = yield self.federation.query_client_keys(
|
||||||
|
destination, {"device_keys": destination_query}
|
||||||
|
)
|
||||||
|
res = res["device_keys"]
|
||||||
|
for user_id, keys in res.items():
|
||||||
|
if user_id in destination_query:
|
||||||
|
results[user_id] = keys
|
||||||
|
|
||||||
|
defer.returnValue((200, {"device_keys": results}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_local_devices(self, query):
|
||||||
|
"""Get E2E device keys for local users
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (dict[string, list[string]|None): map from user_id to a list
|
||||||
|
of devices to query (None for all devices)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
|
||||||
|
map from user_id -> device_id -> device details
|
||||||
|
"""
|
||||||
|
local_query = []
|
||||||
|
|
||||||
|
result_dict = {}
|
||||||
|
for user_id, device_ids in query.items():
|
||||||
|
if not self.is_mine_id(user_id):
|
||||||
|
logger.warning("Request for keys for non-local user %s",
|
||||||
|
user_id)
|
||||||
|
raise errors.SynapseError(400, "Not a user here")
|
||||||
|
|
||||||
|
if not device_ids:
|
||||||
|
local_query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
local_query.append((user_id, device_id))
|
||||||
|
|
||||||
|
# make sure that each queried user appears in the result dict
|
||||||
|
result_dict[user_id] = {}
|
||||||
|
|
||||||
|
results = yield self.store.get_e2e_device_keys(local_query)
|
||||||
|
|
||||||
|
# Build the result structure, un-jsonify the results, and add the
|
||||||
|
# "unsigned" section
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, device_info in device_keys.items():
|
||||||
|
r = json.loads(device_info["key_json"])
|
||||||
|
r["unsigned"] = {}
|
||||||
|
display_name = device_info["device_display_name"]
|
||||||
|
if display_name is not None:
|
||||||
|
r["unsigned"]["device_display_name"] = display_name
|
||||||
|
result_dict[user_id][device_id] = r
|
||||||
|
|
||||||
|
defer.returnValue(result_dict)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_federation_query_client_keys(self, query_body):
|
||||||
|
""" Handle a device key query from a federated server
|
||||||
|
"""
|
||||||
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
|
res = yield self.query_local_devices(device_keys_query)
|
||||||
|
defer.returnValue({"device_keys": res})
|
|
@ -124,7 +124,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||||
auth_chain, state, event
|
origin, auth_chain, state, event
|
||||||
)
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
raise FederationError(
|
raise FederationError(
|
||||||
|
@ -335,18 +335,35 @@ class FederationHandler(BaseHandler):
|
||||||
state_events.update({s.event_id: s for s in state})
|
state_events.update({s.event_id: s for s in state})
|
||||||
events_to_state[e_id] = state
|
events_to_state[e_id] = state
|
||||||
|
|
||||||
seen_events = yield self.store.have_events(
|
|
||||||
set(auth_events.keys()) | set(state_events.keys())
|
|
||||||
)
|
|
||||||
|
|
||||||
all_events = events + state_events.values() + auth_events.values()
|
|
||||||
required_auth = set(
|
required_auth = set(
|
||||||
a_id for event in all_events for a_id, _ in event.auth_events
|
a_id
|
||||||
|
for event in events + state_events.values() + auth_events.values()
|
||||||
|
for a_id, _ in event.auth_events
|
||||||
|
)
|
||||||
|
auth_events.update({
|
||||||
|
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
|
||||||
|
})
|
||||||
|
missing_auth = required_auth - set(auth_events)
|
||||||
|
failed_to_fetch = set()
|
||||||
|
|
||||||
|
# Try and fetch any missing auth events from both DB and remote servers.
|
||||||
|
# We repeatedly do this until we stop finding new auth events.
|
||||||
|
while missing_auth - failed_to_fetch:
|
||||||
|
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||||
|
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
|
||||||
|
auth_events.update(ret_events)
|
||||||
|
|
||||||
|
required_auth.update(
|
||||||
|
a_id for event in ret_events.values() for a_id, _ in event.auth_events
|
||||||
|
)
|
||||||
|
missing_auth = required_auth - set(auth_events)
|
||||||
|
|
||||||
|
if missing_auth - failed_to_fetch:
|
||||||
|
logger.info(
|
||||||
|
"Fetching missing auth for backfill: %r",
|
||||||
|
missing_auth - failed_to_fetch
|
||||||
)
|
)
|
||||||
|
|
||||||
missing_auth = required_auth - set(auth_events)
|
|
||||||
if missing_auth:
|
|
||||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
|
||||||
results = yield defer.gatherResults(
|
results = yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.replication_layer.get_pdu(
|
self.replication_layer.get_pdu(
|
||||||
|
@ -355,11 +372,21 @@ class FederationHandler(BaseHandler):
|
||||||
outlier=True,
|
outlier=True,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
for event_id in missing_auth
|
for event_id in missing_auth - failed_to_fetch
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
auth_events.update({a.event_id: a for a in results})
|
auth_events.update({a.event_id: a for a in results})
|
||||||
|
required_auth.update(
|
||||||
|
a_id for event in results for a_id, _ in event.auth_events
|
||||||
|
)
|
||||||
|
missing_auth = required_auth - set(auth_events)
|
||||||
|
|
||||||
|
failed_to_fetch = missing_auth - set(auth_events)
|
||||||
|
|
||||||
|
seen_events = yield self.store.have_events(
|
||||||
|
set(auth_events.keys()) | set(state_events.keys())
|
||||||
|
)
|
||||||
|
|
||||||
ev_infos = []
|
ev_infos = []
|
||||||
for a in auth_events.values():
|
for a in auth_events.values():
|
||||||
|
@ -372,6 +399,7 @@ class FederationHandler(BaseHandler):
|
||||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||||
auth_events[a_id]
|
auth_events[a_id]
|
||||||
for a_id, _ in a.auth_events
|
for a_id, _ in a.auth_events
|
||||||
|
if a_id in auth_events
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -383,6 +411,7 @@ class FederationHandler(BaseHandler):
|
||||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||||
auth_events[a_id]
|
auth_events[a_id]
|
||||||
for a_id, _ in event_map[e_id].auth_events
|
for a_id, _ in event_map[e_id].auth_events
|
||||||
|
if a_id in auth_events
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -637,7 +666,7 @@ class FederationHandler(BaseHandler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||||
auth_chain, state, event
|
origin, auth_chain, state, event
|
||||||
)
|
)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
@ -688,7 +717,9 @@ class FederationHandler(BaseHandler):
|
||||||
logger.warn("Failed to create join %r because %s", event, e)
|
logger.warn("Failed to create join %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
|
# when we get the event back in `on_send_join_request`
|
||||||
|
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
@ -918,7 +949,9 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
|
# when we get the event back in `on_send_leave_request`
|
||||||
|
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -987,14 +1020,9 @@ class FederationHandler(BaseHandler):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
|
def get_state_for_pdu(self, room_id, event_id):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if do_auth:
|
|
||||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
|
||||||
if not in_room:
|
|
||||||
raise AuthError(403, "Host not in room.")
|
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
room_id, [event_id]
|
room_id, [event_id]
|
||||||
)
|
)
|
||||||
|
@ -1114,6 +1142,7 @@ class FederationHandler(BaseHandler):
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not backfilled:
|
||||||
# this intentionally does not yield: we don't care about the result
|
# this intentionally does not yield: we don't care about the result
|
||||||
# and don't need to wait for it.
|
# and don't need to wait for it.
|
||||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||||
|
@ -1150,11 +1179,19 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _persist_auth_tree(self, auth_events, state, event):
|
def _persist_auth_tree(self, origin, auth_events, state, event):
|
||||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||||
state and event. Then persists the auth chain and state atomically.
|
state and event. Then persists the auth chain and state atomically.
|
||||||
Persists the event seperately.
|
Persists the event seperately.
|
||||||
|
|
||||||
|
Will attempt to fetch missing auth events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (str): Where the events came from
|
||||||
|
auth_events (list)
|
||||||
|
state (list)
|
||||||
|
event (Event)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
||||||
call for `event`
|
call for `event`
|
||||||
|
@ -1167,7 +1204,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event_map = {
|
event_map = {
|
||||||
e.event_id: e
|
e.event_id: e
|
||||||
for e in auth_events
|
for e in itertools.chain(auth_events, state, [event])
|
||||||
}
|
}
|
||||||
|
|
||||||
create_event = None
|
create_event = None
|
||||||
|
@ -1176,10 +1213,29 @@ class FederationHandler(BaseHandler):
|
||||||
create_event = e
|
create_event = e
|
||||||
break
|
break
|
||||||
|
|
||||||
|
missing_auth_events = set()
|
||||||
|
for e in itertools.chain(auth_events, state, [event]):
|
||||||
|
for e_id, _ in e.auth_events:
|
||||||
|
if e_id not in event_map:
|
||||||
|
missing_auth_events.add(e_id)
|
||||||
|
|
||||||
|
for e_id in missing_auth_events:
|
||||||
|
m_ev = yield self.replication_layer.get_pdu(
|
||||||
|
[origin],
|
||||||
|
e_id,
|
||||||
|
outlier=True,
|
||||||
|
timeout=10000,
|
||||||
|
)
|
||||||
|
if m_ev and m_ev.event_id == e_id:
|
||||||
|
event_map[e_id] = m_ev
|
||||||
|
else:
|
||||||
|
logger.info("Failed to find auth event %r", e_id)
|
||||||
|
|
||||||
for e in itertools.chain(auth_events, state, [event]):
|
for e in itertools.chain(auth_events, state, [event]):
|
||||||
auth_for_e = {
|
auth_for_e = {
|
||||||
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
|
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
|
||||||
for e_id, _ in e.auth_events
|
for e_id, _ in e.auth_events
|
||||||
|
if e_id in event_map
|
||||||
}
|
}
|
||||||
if create_event:
|
if create_event:
|
||||||
auth_for_e[(EventTypes.Create, "")] = create_event
|
auth_for_e[(EventTypes.Create, "")] = create_event
|
||||||
|
@ -1413,7 +1469,7 @@ class FederationHandler(BaseHandler):
|
||||||
local_view = dict(auth_events)
|
local_view = dict(auth_events)
|
||||||
remote_view = dict(auth_events)
|
remote_view = dict(auth_events)
|
||||||
remote_view.update({
|
remote_view.update({
|
||||||
(d.type, d.state_key): d for d in different_events
|
(d.type, d.state_key): d for d in different_events if d
|
||||||
})
|
})
|
||||||
|
|
||||||
new_state, prev_state = self.state_handler.resolve_events(
|
new_state, prev_state = self.state_handler.resolve_events(
|
||||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import (
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
|
||||||
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _should_trust_id_server(self, id_server):
|
||||||
|
if id_server not in self.trusted_id_servers:
|
||||||
|
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||||
|
logger.warn(
|
||||||
|
"Trusting untrustworthy ID server %r even though it isn't"
|
||||||
|
" in the trusted id list for testing because"
|
||||||
|
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||||
|
" is set in the config",
|
||||||
|
id_server,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def threepid_from_creds(self, creds):
|
def threepid_from_creds(self, creds):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
@ -59,18 +73,11 @@ class IdentityHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "No client_secret in creds")
|
raise SynapseError(400, "No client_secret in creds")
|
||||||
|
|
||||||
if id_server not in self.trusted_id_servers:
|
if not self._should_trust_id_server(id_server):
|
||||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Trusting untrustworthy ID server %r even though it isn't"
|
'%s is not a trusted ID server: rejecting 3pid ' +
|
||||||
" in the trusted id list for testing because"
|
'credentials', id_server
|
||||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
|
||||||
" is set in the config",
|
|
||||||
id_server,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
|
||||||
'credentials', id_server)
|
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
data = {}
|
data = {}
|
||||||
|
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
|
||||||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if not self._should_trust_id_server(id_server):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Untrusted ID server '%s'" % id_server,
|
||||||
|
Codes.SERVER_NOT_TRUSTED
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'email': email,
|
'email': email,
|
||||||
'client_secret': client_secret,
|
'client_secret': client_secret,
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.types import (
|
||||||
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import concurrently_execute, run_on_reactor
|
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
|
||||||
self.validator = EventValidator()
|
self.validator = EventValidator()
|
||||||
self.snapshot_cache = SnapshotCache()
|
self.snapshot_cache = SnapshotCache()
|
||||||
|
|
||||||
|
self.pagination_lock = ReadWriteLock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def purge_history(self, room_id, event_id):
|
||||||
|
event = yield self.store.get_event(event_id)
|
||||||
|
|
||||||
|
if event.room_id != room_id:
|
||||||
|
raise SynapseError(400, "Event is for wrong room.")
|
||||||
|
|
||||||
|
depth = event.depth
|
||||||
|
|
||||||
|
with (yield self.pagination_lock.write(room_id)):
|
||||||
|
yield self.store.delete_old_state(room_id, depth)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||||
as_client_event=True):
|
as_client_event=True, event_filter=None):
|
||||||
"""Get messages in a room.
|
"""Get messages in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
|
||||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||||
config rules to apply, if any.
|
config rules to apply, if any.
|
||||||
as_client_event (bool): True to get events in client-server format.
|
as_client_event (bool): True to get events in client-server format.
|
||||||
|
event_filter (Filter): Filter to apply to results or None
|
||||||
Returns:
|
Returns:
|
||||||
dict: Pagination API results
|
dict: Pagination API results
|
||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
data_source = self.hs.get_event_sources().sources["room"]
|
|
||||||
|
|
||||||
if pagin_config.from_token:
|
if pagin_config.from_token:
|
||||||
room_token = pagin_config.from_token.room_key
|
room_token = pagin_config.from_token.room_key
|
||||||
|
@ -85,6 +99,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
source_config = pagin_config.get_source_config("room")
|
source_config = pagin_config.get_source_config("room")
|
||||||
|
|
||||||
|
with (yield self.pagination_lock.read(room_id)):
|
||||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||||
room_id, user_id
|
room_id, user_id
|
||||||
)
|
)
|
||||||
|
@ -95,7 +110,7 @@ class MessageHandler(BaseHandler):
|
||||||
if room_token.topological:
|
if room_token.topological:
|
||||||
max_topo = room_token.topological
|
max_topo = room_token.topological
|
||||||
else:
|
else:
|
||||||
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
|
max_topo = yield self.store.get_max_topological_token(
|
||||||
room_id, room_token.stream
|
room_id, room_token.stream
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -114,8 +129,13 @@ class MessageHandler(BaseHandler):
|
||||||
room_id, max_topo
|
room_id, max_topo
|
||||||
)
|
)
|
||||||
|
|
||||||
events, next_key = yield data_source.get_pagination_rows(
|
events, next_key = yield self.store.paginate_room_events(
|
||||||
requester.user, source_config, room_id
|
room_id=room_id,
|
||||||
|
from_key=source_config.from_key,
|
||||||
|
to_key=source_config.to_key,
|
||||||
|
direction=source_config.direction,
|
||||||
|
limit=source_config.limit,
|
||||||
|
event_filter=event_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token = pagin_config.from_token.copy_and_replace(
|
next_token = pagin_config.from_token.copy_and_replace(
|
||||||
|
@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
|
||||||
"end": next_token.to_string(),
|
"end": next_token.to_string(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if event_filter:
|
||||||
|
events = event_filter.filter(events)
|
||||||
|
|
||||||
events = yield filter_events_for_client(
|
events = yield filter_events_for_client(
|
||||||
self.store,
|
self.store,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
@ -13,15 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID, Requester
|
from synapse.types import UserID
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# Assume the user isn't a guest because we don't let guests set
|
# Assume the user isn't a guest because we don't let guests set
|
||||||
# profile or avatar data.
|
# profile or avatar data.
|
||||||
requester = Requester(user, "", False)
|
# XXX why are we recreating `requester` here for each room?
|
||||||
|
# what was wrong with the `requester` we were passed?
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
user,
|
user,
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID, Requester
|
import synapse.types
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
|
||||||
from synapse.util.async import run_on_reactor
|
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
from synapse.types import UserID
|
||||||
import logging
|
from synapse.util.async import run_on_reactor
|
||||||
import urllib
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -52,6 +53,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
Codes.INVALID_USERNAME
|
Codes.INVALID_USERNAME
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if localpart[0] == '_':
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID may not begin with _",
|
||||||
|
Codes.INVALID_USERNAME
|
||||||
|
)
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
@ -90,7 +98,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
password=None,
|
password=None,
|
||||||
generate_token=True,
|
generate_token=True,
|
||||||
guest_access_token=None,
|
guest_access_token=None,
|
||||||
make_guest=False
|
make_guest=False,
|
||||||
|
admin=False,
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -100,6 +109,11 @@ class RegistrationHandler(BaseHandler):
|
||||||
password (str) : The password to assign to this user so they can
|
password (str) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
|
generate_token (bool): Whether a new access token should be
|
||||||
|
generated. Having this be True should be considered deprecated,
|
||||||
|
since it offers no means of associating a device_id with the
|
||||||
|
access_token. Instead you should call auth_handler.issue_access_token
|
||||||
|
after registration.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -141,6 +155,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
# If the user was a guest then they already have a profile
|
# If the user was a guest then they already have a profile
|
||||||
None if was_guest else user.localpart
|
None if was_guest else user.localpart
|
||||||
),
|
),
|
||||||
|
admin=admin,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# autogen a sequential user ID
|
# autogen a sequential user ID
|
||||||
|
@ -194,15 +209,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id, allowed_appservice=service
|
user_id, allowed_appservice=service
|
||||||
)
|
)
|
||||||
|
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
|
||||||
password_hash="",
|
password_hash="",
|
||||||
appservice_id=service_id,
|
appservice_id=service_id,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_localpart=user.localpart,
|
||||||
)
|
)
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_recaptcha(self, ip, private_key, challenge, response):
|
def check_recaptcha(self, ip, private_key, challenge, response):
|
||||||
|
@ -358,7 +371,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
||||||
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
|
||||||
|
@ -387,14 +401,14 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
token = self.auth_handler().generate_short_term_login_token(
|
token = self.auth_handler().generate_access_token(
|
||||||
user_id, duration_seconds)
|
user_id, None, duration_in_ms)
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=None,
|
password_hash=password_hash,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_localpart=user.localpart,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -404,8 +418,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield profile_handler.set_displayname(
|
yield profile_handler.set_displayname(
|
||||||
user, Requester(user, token, False), displayname
|
user, requester, displayname
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
|
@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
|
||||||
class RoomListHandler(BaseHandler):
|
class RoomListHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomListHandler, self).__init__(hs)
|
super(RoomListHandler, self).__init__(hs)
|
||||||
self.response_cache = ResponseCache()
|
self.response_cache = ResponseCache(hs)
|
||||||
self.remote_list_request_cache = ResponseCache()
|
self.remote_list_request_cache = ResponseCache(hs)
|
||||||
self.remote_list_cache = {}
|
self.remote_list_cache = {}
|
||||||
self.fetch_looping_call = hs.get_clock().looping_call(
|
self.fetch_looping_call = hs.get_clock().looping_call(
|
||||||
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
|
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
|
||||||
|
|
|
@ -14,24 +14,22 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
from signedjson.sign import verify_signed_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from ._base import BaseHandler
|
import synapse.types
|
||||||
|
|
||||||
from synapse.types import UserID, RoomID, Requester
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes, Membership,
|
EventTypes, Membership,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||||
|
from synapse.types import UserID, RoomID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.distributor import user_left_room, user_joined_room
|
from synapse.util.distributor import user_left_room, user_joined_room
|
||||||
|
from ._base import BaseHandler
|
||||||
from signedjson.sign import verify_signed_json
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
|
||||||
|
|
||||||
from unpaddedbase64 import decode_base64
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||||
else:
|
else:
|
||||||
requester = Requester(target_user, None, False)
|
requester = synapse.types.create_requester(target_user)
|
||||||
|
|
||||||
message_handler = self.hs.get_handlers().message_handler
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||||
|
|
|
@ -138,7 +138,7 @@ class SyncHandler(object):
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.response_cache = ResponseCache()
|
self.response_cache = ResponseCache(hs)
|
||||||
|
|
||||||
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
||||||
full_state=False):
|
full_state=False):
|
||||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
self._PathEntry(path_pattern, callback)
|
self._PathEntry(path_pattern, callback)
|
||||||
)
|
)
|
||||||
|
|
|
@ -27,7 +27,8 @@ import gc
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
|
||||||
from .metric import (
|
from .metric import (
|
||||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
|
||||||
|
MemoryUsageMetric,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,6 +67,21 @@ class Metrics(object):
|
||||||
return self._register(CacheMetric, *args, **kwargs)
|
return self._register(CacheMetric, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def register_memory_metrics(hs):
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
process = psutil.Process()
|
||||||
|
process.memory_info().rss
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
logger.warn(
|
||||||
|
"psutil is not installed or incorrect version."
|
||||||
|
" Disabling memory metrics."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
metric = MemoryUsageMetric(hs, psutil)
|
||||||
|
all_metrics.append(metric)
|
||||||
|
|
||||||
|
|
||||||
def get_metrics_for(pkg_name):
|
def get_metrics_for(pkg_name):
|
||||||
""" Returns a Metrics instance for conveniently creating metrics
|
""" Returns a Metrics instance for conveniently creating metrics
|
||||||
namespaced with the given name prefix. """
|
namespaced with the given name prefix. """
|
||||||
|
|
|
@ -153,3 +153,43 @@ class CacheMetric(object):
|
||||||
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
||||||
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryUsageMetric(object):
|
||||||
|
"""Keeps track of the current memory usage, using psutil.
|
||||||
|
|
||||||
|
The class will keep the current min/max/sum/counts of rss over the last
|
||||||
|
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
|
||||||
|
"""
|
||||||
|
|
||||||
|
UPDATE_HZ = 2 # number of times to get memory per second
|
||||||
|
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
|
||||||
|
|
||||||
|
def __init__(self, hs, psutil):
|
||||||
|
clock = hs.get_clock()
|
||||||
|
self.memory_snapshots = []
|
||||||
|
|
||||||
|
self.process = psutil.Process()
|
||||||
|
|
||||||
|
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
|
||||||
|
|
||||||
|
def _update_curr_values(self):
|
||||||
|
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
|
||||||
|
self.memory_snapshots.append(self.process.memory_info().rss)
|
||||||
|
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
if not self.memory_snapshots:
|
||||||
|
return []
|
||||||
|
|
||||||
|
max_rss = max(self.memory_snapshots)
|
||||||
|
min_rss = min(self.memory_snapshots)
|
||||||
|
sum_rss = sum(self.memory_snapshots)
|
||||||
|
len_rss = len(self.memory_snapshots)
|
||||||
|
|
||||||
|
return [
|
||||||
|
"process_psutil_rss:max %d" % max_rss,
|
||||||
|
"process_psutil_rss:min %d" % min_rss,
|
||||||
|
"process_psutil_rss:total %d" % sum_rss,
|
||||||
|
"process_psutil_rss:count %d" % len_rss,
|
||||||
|
]
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -92,7 +93,11 @@ class EmailPusher(object):
|
||||||
|
|
||||||
def on_stop(self):
|
def on_stop(self):
|
||||||
if self.timed_call:
|
if self.timed_call:
|
||||||
|
try:
|
||||||
self.timed_call.cancel()
|
self.timed_call.cancel()
|
||||||
|
except (AlreadyCalled, AlreadyCancelled):
|
||||||
|
pass
|
||||||
|
self.timed_call = None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
||||||
|
@ -140,9 +145,8 @@ class EmailPusher(object):
|
||||||
being run.
|
being run.
|
||||||
"""
|
"""
|
||||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||||
self.user_id, start, self.max_stream_ordering
|
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
|
||||||
)
|
|
||||||
|
|
||||||
soonest_due_at = None
|
soonest_due_at = None
|
||||||
|
|
||||||
|
@ -190,7 +194,10 @@ class EmailPusher(object):
|
||||||
soonest_due_at = should_notify_at
|
soonest_due_at = should_notify_at
|
||||||
|
|
||||||
if self.timed_call is not None:
|
if self.timed_call is not None:
|
||||||
|
try:
|
||||||
self.timed_call.cancel()
|
self.timed_call.cancel()
|
||||||
|
except (AlreadyCalled, AlreadyCancelled):
|
||||||
|
pass
|
||||||
self.timed_call = None
|
self.timed_call = None
|
||||||
|
|
||||||
if soonest_due_at is not None:
|
if soonest_due_at is not None:
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from synapse.push import PusherConfigException
|
from synapse.push import PusherConfigException
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import push_rule_evaluator
|
import push_rule_evaluator
|
||||||
|
@ -38,6 +39,7 @@ class HttpPusher(object):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
self.state_handler = self.hs.get_state_handler()
|
||||||
self.user_id = pusherdict['user_name']
|
self.user_id = pusherdict['user_name']
|
||||||
self.app_id = pusherdict['app_id']
|
self.app_id = pusherdict['app_id']
|
||||||
self.app_display_name = pusherdict['app_display_name']
|
self.app_display_name = pusherdict['app_display_name']
|
||||||
|
@ -108,7 +110,11 @@ class HttpPusher(object):
|
||||||
|
|
||||||
def on_stop(self):
|
def on_stop(self):
|
||||||
if self.timed_call:
|
if self.timed_call:
|
||||||
|
try:
|
||||||
self.timed_call.cancel()
|
self.timed_call.cancel()
|
||||||
|
except (AlreadyCalled, AlreadyCancelled):
|
||||||
|
pass
|
||||||
|
self.timed_call = None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _process(self):
|
def _process(self):
|
||||||
|
@ -140,7 +146,8 @@ class HttpPusher(object):
|
||||||
run once per pusher.
|
run once per pusher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||||
|
unprocessed = yield fn(
|
||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -237,7 +244,9 @@ class HttpPusher(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _build_notification_dict(self, event, tweaks, badge):
|
def _build_notification_dict(self, event, tweaks, badge):
|
||||||
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
|
ctx = yield push_tools.get_context_for_event(
|
||||||
|
self.state_handler, event, self.user_id
|
||||||
|
)
|
||||||
|
|
||||||
d = {
|
d = {
|
||||||
'notification': {
|
'notification': {
|
||||||
|
@ -269,8 +278,8 @@ class HttpPusher(object):
|
||||||
if 'content' in event:
|
if 'content' in event:
|
||||||
d['notification']['content'] = event.content
|
d['notification']['content'] = event.content
|
||||||
|
|
||||||
if len(ctx['aliases']):
|
# We no longer send aliases separately, instead, we send the human
|
||||||
d['notification']['room_alias'] = ctx['aliases'][0]
|
# readable name of the room, which may be an alias.
|
||||||
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
|
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
|
||||||
d['notification']['sender_display_name'] = ctx['sender_display_name']
|
d['notification']['sender_display_name'] = ctx['sender_display_name']
|
||||||
if 'name' in ctx and len(ctx['name']) > 0:
|
if 'name' in ctx and len(ctx['name']) > 0:
|
||||||
|
|
|
@ -14,6 +14,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from synapse.util.presentable_names import (
|
||||||
|
calculate_room_name, name_from_member_event
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_context_for_event(store, ev):
|
def get_context_for_event(state_handler, ev, user_id):
|
||||||
name_aliases = yield store.get_room_name_and_aliases(
|
ctx = {}
|
||||||
ev.room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx = {'aliases': name_aliases[1]}
|
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||||
if name_aliases[0] is not None:
|
|
||||||
ctx['name'] = name_aliases[0]
|
|
||||||
|
|
||||||
their_member_events_for_room = yield store.get_current_state(
|
# we no longer bother setting room_alias, and make room_name the
|
||||||
room_id=ev.room_id,
|
# human-readable name instead, be that m.room.name, an alias or
|
||||||
event_type='m.room.member',
|
# a list of people in the room
|
||||||
state_key=ev.user_id
|
name = calculate_room_name(
|
||||||
|
room_state, user_id, fallback_to_single_member=False
|
||||||
)
|
)
|
||||||
for mev in their_member_events_for_room:
|
if name:
|
||||||
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
|
ctx['name'] = name
|
||||||
dn = mev.content['displayname']
|
|
||||||
if dn is not None:
|
sender_state_event = room_state[("m.room.member", ev.sender)]
|
||||||
ctx['sender_display_name'] = dn
|
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||||
|
|
||||||
defer.returnValue(ctx)
|
defer.returnValue(ctx)
|
||||||
|
|
|
@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
|
||||||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||||
},
|
},
|
||||||
|
"ldap": {
|
||||||
|
"ldap3>=1.0": ["ldap3>=1.0"],
|
||||||
|
},
|
||||||
|
"psutil": {
|
||||||
|
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
23
synapse/replication/slave/storage/directory.py
Normal file
23
synapse/replication/slave/storage/directory.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from synapse.storage.directory import DirectoryStore
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryStore(BaseSlavedStore):
|
||||||
|
get_aliases_for_room = DirectoryStore.__dict__[
|
||||||
|
"get_aliases_for_room"
|
||||||
|
].orig
|
|
@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.room import RoomStore
|
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||||
|
@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
# Cached functions can't be accessed through a class instance so we need
|
# Cached functions can't be accessed through a class instance so we need
|
||||||
# to reach inside the __dict__ to extract them.
|
# to reach inside the __dict__ to extract them.
|
||||||
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
|
|
||||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
||||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
||||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
||||||
|
@ -95,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||||
)
|
)
|
||||||
|
|
||||||
get_unread_push_actions_for_user_in_range = (
|
get_unread_push_actions_for_user_in_range_for_http = (
|
||||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||||
|
)
|
||||||
|
get_unread_push_actions_for_user_in_range_for_email = (
|
||||||
|
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
||||||
)
|
)
|
||||||
get_push_action_users_in_range = (
|
get_push_action_users_in_range = (
|
||||||
DataStore.get_push_action_users_in_range.__func__
|
DataStore.get_push_action_users_in_range.__func__
|
||||||
|
@ -144,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
||||||
|
|
||||||
|
get_backfill_events = DataStore.get_backfill_events.__func__
|
||||||
|
_get_backfill_events = DataStore._get_backfill_events.__func__
|
||||||
|
get_missing_events = DataStore.get_missing_events.__func__
|
||||||
|
_get_missing_events = DataStore._get_missing_events.__func__
|
||||||
|
|
||||||
|
get_auth_chain = DataStore.get_auth_chain.__func__
|
||||||
|
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
|
||||||
|
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
result = super(SlavedEventStore, self).stream_positions()
|
||||||
result["events"] = self._stream_id_gen.get_current_token()
|
result["events"] = self._stream_id_gen.get_current_token()
|
||||||
|
@ -202,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
self.get_rooms_for_user.invalidate_all()
|
self.get_rooms_for_user.invalidate_all()
|
||||||
self.get_users_in_room.invalidate((event.room_id,))
|
self.get_users_in_room.invalidate((event.room_id,))
|
||||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||||
self.get_room_name_and_aliases.invalidate((event.room_id,))
|
|
||||||
|
|
||||||
self._invalidate_get_event_cache(event.event_id)
|
self._invalidate_get_event_cache(event.event_id)
|
||||||
|
|
||||||
|
@ -246,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
self._get_current_state_for_key.invalidate((
|
self._get_current_state_for_key.invalidate((
|
||||||
event.room_id, event.type, event.state_key
|
event.room_id, event.type, event.state_key
|
||||||
))
|
))
|
||||||
|
|
||||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
|
||||||
self.get_room_name_and_aliases.invalidate(
|
|
||||||
(event.room_id,)
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
33
synapse/replication/slave/storage/keys.py
Normal file
33
synapse/replication/slave/storage/keys.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.keys import KeyStore
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedKeyStore(BaseSlavedStore):
|
||||||
|
_get_server_verify_key = KeyStore.__dict__[
|
||||||
|
"_get_server_verify_key"
|
||||||
|
]
|
||||||
|
|
||||||
|
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
|
||||||
|
store_server_verify_key = DataStore.store_server_verify_key.__func__
|
||||||
|
|
||||||
|
get_server_certificate = DataStore.get_server_certificate.__func__
|
||||||
|
store_server_certificate = DataStore.store_server_certificate.__func__
|
||||||
|
|
||||||
|
get_server_keys_json = DataStore.get_server_keys_json.__func__
|
||||||
|
store_server_keys_json = DataStore.store_server_keys_json.__func__
|
21
synapse/replication/slave/storage/room.py
Normal file
21
synapse/replication/slave/storage/room.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
|
||||||
|
|
||||||
|
class RoomStore(BaseSlavedStore):
|
||||||
|
get_public_room_ids = DataStore.get_public_room_ids.__func__
|
30
synapse/replication/slave/storage/transactions.py
Normal file
30
synapse/replication/slave/storage/transactions.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.transactions import TransactionStore
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionStore(BaseSlavedStore):
|
||||||
|
get_destination_retry_timings = TransactionStore.__dict__[
|
||||||
|
"get_destination_retry_timings"
|
||||||
|
].orig
|
||||||
|
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
||||||
|
|
||||||
|
# For now, don't record the destination rety timings
|
||||||
|
def set_destination_retry_timings(*args, **kwargs):
|
||||||
|
return defer.succeed(None)
|
|
@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import (
|
||||||
account_data,
|
account_data,
|
||||||
report_event,
|
report_event,
|
||||||
openid,
|
openid,
|
||||||
|
devices,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
@ -90,3 +91,4 @@ class ClientRestResource(JsonResource):
|
||||||
account_data.register_servlets(hs, client_resource)
|
account_data.register_servlets(hs, client_resource)
|
||||||
report_event.register_servlets(hs, client_resource)
|
report_event.register_servlets(hs, client_resource)
|
||||||
openid.register_servlets(hs, client_resource)
|
openid.register_servlets(hs, client_resource)
|
||||||
|
devices.register_servlets(hs, client_resource)
|
||||||
|
|
|
@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
|
||||||
defer.returnValue((200, ret))
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns("/admin/purge_media_cache")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.media_repository = hs.get_media_repository()
|
||||||
|
super(PurgeMediaCacheRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
before_ts = request.args.get("before_ts", None)
|
||||||
|
if not before_ts:
|
||||||
|
raise SynapseError(400, "Missing 'before_ts' arg")
|
||||||
|
|
||||||
|
logger.info("before_ts: %r", before_ts[0])
|
||||||
|
|
||||||
|
try:
|
||||||
|
before_ts = int(before_ts[0])
|
||||||
|
except Exception:
|
||||||
|
raise SynapseError(400, "Invalid 'before_ts' arg")
|
||||||
|
|
||||||
|
ret = yield self.media_repository.delete_old_remote_media(before_ts)
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns(
|
||||||
|
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, room_id, event_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, target_user_id):
|
||||||
|
UserID.from_string(target_user_id)
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
# FIXME: Theoretically there is a race here wherein user resets password
|
||||||
|
# using threepid.
|
||||||
|
yield self.store.user_delete_access_tokens(target_user_id)
|
||||||
|
yield self.store.user_delete_threepids(target_user_id)
|
||||||
|
yield self.store.user_set_password_hash(target_user_id, None)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
WhoisRestServlet(hs).register(http_server)
|
WhoisRestServlet(hs).register(http_server)
|
||||||
|
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||||
|
DeactivateAccountRestServlet(hs).register(http_server)
|
||||||
|
PurgeHistoryRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer):
|
||||||
|
"""
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self.builder_factory = hs.get_event_builder_factory()
|
self.builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
|
@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
self.servername = hs.config.server_name
|
self.servername = hs.config.server_name
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
|
self.device_handler = self.hs.get_device_handler()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
flows = []
|
flows = []
|
||||||
|
@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
).to_string()
|
).to_string()
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
user_id = yield auth_handler.validate_password_login(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"],
|
||||||
|
)
|
||||||
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
|
access_token, refresh_token = (
|
||||||
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_id = (
|
user_id = (
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
user_id, access_token, refresh_token = (
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
access_token, refresh_token = (
|
||||||
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if user_exists:
|
if registered_user_id:
|
||||||
user_id, access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": registered_user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
|
@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if user_exists:
|
if registered_user_id:
|
||||||
user_id, access_token, refresh_token = (
|
device_id = yield self._register_device(
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
registered_user_id, login_submission
|
||||||
|
)
|
||||||
|
access_token, refresh_token = (
|
||||||
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
registered_user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": registered_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
# TODO: we should probably check that the register isn't going
|
||||||
|
# to fonx/change our user_id before registering the device
|
||||||
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
user_id, access_token = (
|
user_id, access_token = (
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
return (user, attributes)
|
return (user, attributes)
|
||||||
|
|
||||||
|
def _register_device(self, user_id, login_submission):
|
||||||
|
"""Register a device for a user.
|
||||||
|
|
||||||
|
This is called after the user's credentials have been validated, but
|
||||||
|
before the access token has been issued.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(str) user_id: full canonical @user:id
|
||||||
|
(object) login_submission: dictionary supplied to /login call, from
|
||||||
|
which we pull device_id and initial_device_name
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (str) device_id
|
||||||
|
"""
|
||||||
|
device_id = login_submission.get("device_id")
|
||||||
|
initial_display_name = login_submission.get(
|
||||||
|
"initial_device_display_name")
|
||||||
|
return self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SAML2RestServlet(ClientV1RestServlet):
|
class SAML2RestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login/saml2", releases=())
|
PATTERNS = client_path_patterns("/login/saml2", releases=())
|
||||||
|
@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if not user_exists:
|
if not registered_user_id:
|
||||||
user_id, _ = (
|
registered_user_id, _ = (
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
||||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||||
login_token)
|
login_token)
|
||||||
request.redirect(redirect_url)
|
request.redirect(redirect_url)
|
||||||
|
|
|
@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
|
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
super(RegisterRestServlet, self).__init__(hs)
|
super(RegisterRestServlet, self).__init__(hs)
|
||||||
# sessions are stored as:
|
# sessions are stored as:
|
||||||
# self.sessions = {
|
# self.sessions = {
|
||||||
|
@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
# TODO: persistent storage
|
# TODO: persistent storage
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.enable_registration = hs.config.enable_registration
|
self.enable_registration = hs.config.enable_registration
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
|
@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
user_localpart = register_json["user"].encode("utf-8")
|
user_localpart = register_json["user"].encode("utf-8")
|
||||||
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
(user_id, token) = yield handler.appservice_register(
|
user_id = yield handler.appservice_register(
|
||||||
user_localpart, as_token
|
user_localpart, as_token
|
||||||
)
|
)
|
||||||
|
token = yield self.auth_handler.issue_access_token(user_id)
|
||||||
self._remove_session(session)
|
self._remove_session(session)
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
user = register_json["user"].encode("utf-8")
|
user = register_json["user"].encode("utf-8")
|
||||||
|
password = register_json["password"].encode("utf-8")
|
||||||
|
admin = register_json.get("admin", None)
|
||||||
|
|
||||||
|
# Its important to check as we use null bytes as HMAC field separators
|
||||||
|
if "\x00" in user:
|
||||||
|
raise SynapseError(400, "Invalid user")
|
||||||
|
if "\x00" in password:
|
||||||
|
raise SynapseError(400, "Invalid password")
|
||||||
|
|
||||||
# str() because otherwise hmac complains that 'unicode' does not
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
# have the buffer interface
|
# have the buffer interface
|
||||||
|
@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
want_mac = hmac.new(
|
want_mac = hmac.new(
|
||||||
key=self.hs.config.registration_shared_secret,
|
key=self.hs.config.registration_shared_secret,
|
||||||
msg=user,
|
|
||||||
digestmod=sha1,
|
digestmod=sha1,
|
||||||
).hexdigest()
|
)
|
||||||
|
want_mac.update(user)
|
||||||
password = register_json["password"].encode("utf-8")
|
want_mac.update("\x00")
|
||||||
|
want_mac.update(password)
|
||||||
|
want_mac.update("\x00")
|
||||||
|
want_mac.update("admin" if admin else "notadmin")
|
||||||
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if compare_digest(want_mac, got_mac):
|
if compare_digest(want_mac, got_mac):
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.register(
|
user_id, token = yield handler.register(
|
||||||
localpart=user,
|
localpart=user,
|
||||||
password=password,
|
password=password,
|
||||||
|
admin=bool(admin),
|
||||||
)
|
)
|
||||||
self._remove_session(session)
|
self._remove_session(session)
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
|
@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||||
if duration_seconds > self.direct_user_creation_max_duration:
|
if duration_seconds > self.direct_user_creation_max_duration:
|
||||||
duration_seconds = self.direct_user_creation_max_duration
|
duration_seconds = self.direct_user_creation_max_duration
|
||||||
|
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||||
|
if user_json.get("password_hash") else None
|
||||||
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.get_or_create_user(
|
user_id, token = yield handler.get_or_create_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_seconds=duration_seconds
|
duration_in_ms=(duration_seconds * 1000),
|
||||||
|
password_hash=password_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
|
|
|
@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
|
||||||
from synapse.api.errors import SynapseError, Codes, AuthError
|
from synapse.api.errors import SynapseError, Codes, AuthError
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
from synapse.api.filtering import Filter
|
||||||
from synapse.types import UserID, RoomID, RoomAlias
|
from synapse.types import UserID, RoomID, RoomAlias
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
request, default_limit=10,
|
request, default_limit=10,
|
||||||
)
|
)
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = "raw" not in request.args
|
||||||
|
filter_bytes = request.args.get("filter", None)
|
||||||
|
if filter_bytes:
|
||||||
|
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
|
||||||
|
event_filter = Filter(json.loads(filter_json))
|
||||||
|
else:
|
||||||
|
event_filter = None
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
msgs = yield handler.get_messages(
|
msgs = yield handler.get_messages(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
requester=requester,
|
requester=requester,
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
as_client_event=as_client_event
|
as_client_event=as_client_event,
|
||||||
|
event_filter=event_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, msgs))
|
defer.returnValue((200, msgs))
|
||||||
|
|
|
@ -25,7 +25,9 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def client_v2_patterns(path_regex, releases=(0,)):
|
def client_v2_patterns(path_regex, releases=(0,),
|
||||||
|
v2_alpha=True,
|
||||||
|
unstable=True):
|
||||||
"""Creates a regex compiled client path with the correct client path
|
"""Creates a regex compiled client path with the correct client path
|
||||||
prefix.
|
prefix.
|
||||||
|
|
||||||
|
@ -35,7 +37,10 @@ def client_v2_patterns(path_regex, releases=(0,)):
|
||||||
Returns:
|
Returns:
|
||||||
SRE_Pattern
|
SRE_Pattern
|
||||||
"""
|
"""
|
||||||
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
|
patterns = []
|
||||||
|
if v2_alpha:
|
||||||
|
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
|
||||||
|
if unstable:
|
||||||
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
||||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||||
for release in releases:
|
for release in releases:
|
||||||
|
|
|
@ -28,8 +28,40 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordRequestTokenRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(PasswordRequestTokenRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if absent:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
'email', body['email']
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is None:
|
||||||
|
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
class PasswordRestServlet(RestServlet):
|
class PasswordRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/password")
|
PATTERNS = client_v2_patterns("/account/password$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(PasswordRestServlet, self).__init__()
|
super(PasswordRestServlet, self).__init__()
|
||||||
|
@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
class DeactivateAccountRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/deactivate$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
super(DeactivateAccountRestServlet, self).__init__()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||||
|
[LoginType.PASSWORD],
|
||||||
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
|
if not authed:
|
||||||
|
defer.returnValue((401, result))
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
requester = None
|
||||||
|
|
||||||
|
if LoginType.PASSWORD in result:
|
||||||
|
# if using password, they should also be logged in
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
if user_id != result[LoginType.PASSWORD]:
|
||||||
|
raise LoginError(400, "", Codes.UNKNOWN)
|
||||||
|
else:
|
||||||
|
logger.error("Auth succeeded but no known type!", result.keys())
|
||||||
|
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||||
|
|
||||||
|
# FIXME: Theoretically there is a race here wherein user resets password
|
||||||
|
# using threepid.
|
||||||
|
yield self.store.user_delete_access_tokens(user_id)
|
||||||
|
yield self.store.user_delete_threepids(user_id)
|
||||||
|
yield self.store.user_set_password_hash(user_id, None)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ThreepidRequestTokenRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
super(ThreepidRequestTokenRestServlet, self).__init__()
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if absent:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
'email', body['email']
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is not None:
|
||||||
|
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
class ThreepidRestServlet(RestServlet):
|
class ThreepidRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/3pid")
|
PATTERNS = client_v2_patterns("/account/3pid$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThreepidRestServlet, self).__init__()
|
super(ThreepidRestServlet, self).__init__()
|
||||||
|
@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
PasswordRequestTokenRestServlet(hs).register(http_server)
|
||||||
PasswordRestServlet(hs).register(http_server)
|
PasswordRestServlet(hs).register(http_server)
|
||||||
|
DeactivateAccountRestServlet(hs).register(http_server)
|
||||||
|
ThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||||
ThreepidRestServlet(hs).register(http_server)
|
ThreepidRestServlet(hs).register(http_server)
|
||||||
|
|
100
synapse/rest/client/v2_alpha/devices.py
Normal file
100
synapse/rest/client/v2_alpha/devices.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http import servlet
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DevicesRestServlet(servlet.RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(DevicesRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
devices = yield self.device_handler.get_devices_by_user(
|
||||||
|
requester.user.to_string()
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {"devices": devices}))
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||||
|
releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(DeviceRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
device = yield self.device_handler.get_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
defer.returnValue((200, device))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, device_id):
|
||||||
|
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||||
|
# It allows the client to delete access tokens, which feels like a
|
||||||
|
# thing which merits extra auth. But if we want to do the interactive-
|
||||||
|
# auth dance, we should really make it possible to delete more than one
|
||||||
|
# device at a time.
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
yield self.device_handler.delete_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
yield self.device_handler.update_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
body
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
DevicesRestServlet(hs).register(http_server)
|
||||||
|
DeviceRestServlet(hs).register(http_server)
|
|
@ -13,24 +13,25 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
|
import synapse.server
|
||||||
|
import synapse.types
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
import logging
|
|
||||||
import simplejson as json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KeyUploadServlet(RestServlet):
|
class KeyUploadServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
POST /keys/upload/<device_id> HTTP/1.1
|
POST /keys/upload HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
|
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||||
|
releases=())
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
super(KeyUploadServlet, self).__init__()
|
super(KeyUploadServlet, self).__init__()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
# TODO: Check that the device_id matches that in the authentication
|
|
||||||
# or derive the device_id from the authentication instead.
|
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if device_id is not None:
|
||||||
|
# passing the device_id here is deprecated; however, we allow it
|
||||||
|
# for now for compatibility with older clients.
|
||||||
|
if (requester.device_id is not None and
|
||||||
|
device_id != requester.device_id):
|
||||||
|
logger.warning("Client uploading keys for a different device "
|
||||||
|
"(logged in as %s, uploading for %s)",
|
||||||
|
requester.device_id, device_id)
|
||||||
|
else:
|
||||||
|
device_id = requester.device_id
|
||||||
|
|
||||||
|
if device_id is None:
|
||||||
|
raise synapse.api.errors.SynapseError(
|
||||||
|
400,
|
||||||
|
"To upload keys, you must pass device_id when authenticating"
|
||||||
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
# TODO: Validate the JSON to make sure it has the right keys.
|
# TODO: Validate the JSON to make sure it has the right keys.
|
||||||
|
@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
|
||||||
user_id, device_id, time_now, key_list
|
user_id, device_id, time_now, key_list
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
# the device should have been registered already, but it may have been
|
||||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
# deleted due to a race with a DELETE request. Or we may be using an
|
||||||
|
# old access_token without an associated device_id. Either way, we
|
||||||
@defer.inlineCallbacks
|
# need to double-check the device is registered to avoid ending up with
|
||||||
def on_GET(self, request, device_id):
|
# keys without a corresponding device.
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
self.device_handler.check_device_registered(user_id, device_id)
|
||||||
user_id = requester.user.to_string()
|
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer):
|
||||||
|
"""
|
||||||
super(KeyQueryServlet, self).__init__()
|
super(KeyQueryServlet, self).__init__()
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.federation = hs.get_replication_layer()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
self.is_mine = hs.is_mine
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id):
|
def on_POST(self, request, user_id, device_id):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.handle_request(body)
|
result = yield self.e2e_keys_handler.query_devices(body)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
|
||||||
auth_user_id = requester.user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
user_id = user_id if user_id else auth_user_id
|
user_id = user_id if user_id else auth_user_id
|
||||||
device_ids = [device_id] if device_id else []
|
device_ids = [device_id] if device_id else []
|
||||||
result = yield self.handle_request(
|
result = yield self.e2e_keys_handler.query_devices(
|
||||||
{"device_keys": {user_id: device_ids}}
|
{"device_keys": {user_id: device_ids}}
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def handle_request(self, body):
|
|
||||||
local_query = []
|
|
||||||
remote_queries = {}
|
|
||||||
for user_id, device_ids in body.get("device_keys", {}).items():
|
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
if self.is_mine(user):
|
|
||||||
if not device_ids:
|
|
||||||
local_query.append((user_id, None))
|
|
||||||
else:
|
|
||||||
for device_id in device_ids:
|
|
||||||
local_query.append((user_id, device_id))
|
|
||||||
else:
|
|
||||||
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
|
||||||
device_ids
|
|
||||||
)
|
|
||||||
results = yield self.store.get_e2e_device_keys(local_query)
|
|
||||||
|
|
||||||
json_result = {}
|
|
||||||
for user_id, device_keys in results.items():
|
|
||||||
for device_id, json_bytes in device_keys.items():
|
|
||||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
|
||||||
json_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
for destination, device_keys in remote_queries.items():
|
|
||||||
remote_result = yield self.federation.query_client_keys(
|
|
||||||
destination, {"device_keys": device_keys}
|
|
||||||
)
|
|
||||||
for user_id, keys in remote_result["device_keys"].items():
|
|
||||||
if user_id in device_keys:
|
|
||||||
json_result[user_id] = keys
|
|
||||||
defer.returnValue((200, {"device_keys": json_result}))
|
|
||||||
|
|
||||||
|
|
||||||
class OneTimeKeyServlet(RestServlet):
|
class OneTimeKeyServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -41,17 +41,59 @@ else:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServlet(RestServlet):
|
class RegisterRequestTokenRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/register")
|
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(RegisterRequestTokenRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if len(absent) > 0:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
'email', body['email']
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is not None:
|
||||||
|
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/register$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
super(RegisterRestServlet, self).__init__()
|
super(RegisterRestServlet, self).__init__()
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
|
||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if '/register/email/requestToken' in request.path:
|
|
||||||
ret = yield self.onEmailTokenRequest(request)
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
# we do basic sanity checks here because the auth layer will store these
|
# we do basic sanity checks here because the auth layer will store these
|
||||||
|
@ -104,10 +142,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
# Set the desired user according to the AS API (which uses the
|
# Set the desired user according to the AS API (which uses the
|
||||||
# 'user' key not 'username'). Since this is a new addition, we'll
|
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||||
# fallback to 'username' if they gave one.
|
# fallback to 'username' if they gave one.
|
||||||
if isinstance(body.get("user"), basestring):
|
desired_username = body.get("user", desired_username)
|
||||||
desired_username = body["user"]
|
|
||||||
|
if isinstance(desired_username, basestring):
|
||||||
result = yield self._do_appservice_registration(
|
result = yield self._do_appservice_registration(
|
||||||
desired_username, request.args["access_token"][0]
|
desired_username, request.args["access_token"][0], body
|
||||||
)
|
)
|
||||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||||
return
|
return
|
||||||
|
@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
# FIXME: Should we really be determining if this is shared secret
|
# FIXME: Should we really be determining if this is shared secret
|
||||||
# auth based purely on the 'mac' key?
|
# auth based purely on the 'mac' key?
|
||||||
result = yield self._do_shared_secret_registration(
|
result = yield self._do_shared_secret_registration(
|
||||||
desired_username, desired_password, body["mac"]
|
desired_username, desired_password, body
|
||||||
)
|
)
|
||||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||||
return
|
return
|
||||||
|
@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
]
|
]
|
||||||
|
|
||||||
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not authed:
|
if not authed:
|
||||||
defer.returnValue((401, result))
|
defer.returnValue((401, auth_result))
|
||||||
return
|
return
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
|
@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
registered_user_id
|
registered_user_id
|
||||||
)
|
)
|
||||||
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
# don't re-register the email address
|
||||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
add_email = False
|
||||||
registered_user_id
|
else:
|
||||||
)
|
|
||||||
defer.returnValue((200, {
|
|
||||||
"user_id": registered_user_id,
|
|
||||||
"access_token": access_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
}))
|
|
||||||
|
|
||||||
# NB: This may be from the auth handler and NOT from the POST
|
# NB: This may be from the auth handler and NOT from the POST
|
||||||
if 'password' not in params:
|
if 'password' not in params:
|
||||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing password.",
|
||||||
|
Codes.MISSING_PARAM)
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
new_password = params.get("password", None)
|
new_password = params.get("password", None)
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(registered_user_id, _) = yield self.registration_handler.register(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
password=new_password,
|
password=new_password,
|
||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
|
generate_token=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# remember that we've now registered that user account, and with what
|
# remember that we've now registered that user account, and with
|
||||||
# user ID (since the user may not have specified)
|
# what user ID (since the user may not have specified)
|
||||||
self.auth_handler.set_session_data(
|
self.auth_handler.set_session_data(
|
||||||
session_id, "registered_user_id", user_id
|
session_id, "registered_user_id", registered_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result and LoginType.EMAIL_IDENTITY in result:
|
add_email = True
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
|
||||||
|
|
||||||
for reqd in ['medium', 'address', 'validated_at']:
|
return_dict = yield self._create_registration_details(
|
||||||
if reqd not in threepid:
|
registered_user_id, params
|
||||||
logger.info("Can't add incomplete 3pid")
|
|
||||||
else:
|
|
||||||
yield self.auth_handler.add_threepid(
|
|
||||||
user_id,
|
|
||||||
threepid['medium'],
|
|
||||||
threepid['address'],
|
|
||||||
threepid['validated_at'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# And we add an email pusher for them by default, but only
|
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
|
||||||
# if email notifications are enabled (so people don't start
|
threepid = auth_result[LoginType.EMAIL_IDENTITY]
|
||||||
# getting mail spam where they weren't before if email
|
yield self._register_email_threepid(
|
||||||
# notifs are set up on a home server)
|
registered_user_id, threepid, return_dict["access_token"],
|
||||||
if (
|
params.get("bind_email")
|
||||||
self.hs.config.email_enable_notifs and
|
|
||||||
self.hs.config.email_notif_for_new_users
|
|
||||||
):
|
|
||||||
# Pull the ID of the access token back out of the db
|
|
||||||
# It would really make more sense for this to be passed
|
|
||||||
# up when the access token is saved, but that's quite an
|
|
||||||
# invasive change I'd rather do separately.
|
|
||||||
user_tuple = yield self.store.get_user_by_access_token(
|
|
||||||
token
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.hs.get_pusherpool().add_pusher(
|
defer.returnValue((200, return_dict))
|
||||||
user_id=user_id,
|
|
||||||
access_token=user_tuple["token_id"],
|
|
||||||
kind="email",
|
|
||||||
app_id="m.email",
|
|
||||||
app_display_name="Email Notifications",
|
|
||||||
device_display_name=threepid["address"],
|
|
||||||
pushkey=threepid["address"],
|
|
||||||
lang=None, # We don't know a user's language here
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
if 'bind_email' in params and params['bind_email']:
|
|
||||||
logger.info("bind_email specified: binding")
|
|
||||||
|
|
||||||
emailThreepid = result[LoginType.EMAIL_IDENTITY]
|
|
||||||
threepid_creds = emailThreepid['threepid_creds']
|
|
||||||
logger.debug("Binding emails %s to %s" % (
|
|
||||||
emailThreepid, user_id
|
|
||||||
))
|
|
||||||
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
|
|
||||||
else:
|
|
||||||
logger.info("bind_email not specified: not binding email")
|
|
||||||
|
|
||||||
result = yield self._create_registration_details(user_id, token)
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_appservice_registration(self, username, as_token):
|
def _do_appservice_registration(self, username, as_token, body):
|
||||||
(user_id, token) = yield self.registration_handler.appservice_register(
|
user_id = yield self.registration_handler.appservice_register(
|
||||||
username, as_token
|
username, as_token
|
||||||
)
|
)
|
||||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
defer.returnValue((yield self._create_registration_details(user_id, body)))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_shared_secret_registration(self, username, password, mac):
|
def _do_shared_secret_registration(self, username, password, body):
|
||||||
if not self.hs.config.registration_shared_secret:
|
if not self.hs.config.registration_shared_secret:
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
|
@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
# str() because otherwise hmac complains that 'unicode' does not
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
# have the buffer interface
|
# have the buffer interface
|
||||||
got_mac = str(mac)
|
got_mac = str(body["mac"])
|
||||||
|
|
||||||
want_mac = hmac.new(
|
want_mac = hmac.new(
|
||||||
key=self.hs.config.registration_shared_secret,
|
key=self.hs.config.registration_shared_secret,
|
||||||
|
@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
|
||||||
403, "HMAC incorrect",
|
403, "HMAC incorrect",
|
||||||
)
|
)
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(user_id, _) = yield self.registration_handler.register(
|
||||||
localpart=username, password=password
|
localpart=username, password=password, generate_token=False,
|
||||||
)
|
)
|
||||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
|
||||||
|
result = yield self._create_registration_details(user_id, body)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_registration_details(self, user_id, token):
|
def _register_email_threepid(self, user_id, threepid, token, bind_email):
|
||||||
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
|
"""Add an email address as a 3pid identifier
|
||||||
|
|
||||||
|
Also adds an email pusher for the email address, if configured in the
|
||||||
|
HS config
|
||||||
|
|
||||||
|
Also optionally binds emails to the given user_id on the identity server
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): id of user
|
||||||
|
threepid (object): m.login.email.identity auth response
|
||||||
|
token (str): access_token for the user
|
||||||
|
bind_email (bool): true if the client requested the email to be
|
||||||
|
bound at the identity server
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
reqd = ('medium', 'address', 'validated_at')
|
||||||
|
if any(x not in threepid for x in reqd):
|
||||||
|
logger.info("Can't add incomplete 3pid")
|
||||||
|
defer.returnValue()
|
||||||
|
|
||||||
|
yield self.auth_handler.add_threepid(
|
||||||
|
user_id,
|
||||||
|
threepid['medium'],
|
||||||
|
threepid['address'],
|
||||||
|
threepid['validated_at'],
|
||||||
|
)
|
||||||
|
|
||||||
|
# And we add an email pusher for them by default, but only
|
||||||
|
# if email notifications are enabled (so people don't start
|
||||||
|
# getting mail spam where they weren't before if email
|
||||||
|
# notifs are set up on a home server)
|
||||||
|
if (self.hs.config.email_enable_notifs and
|
||||||
|
self.hs.config.email_notif_for_new_users):
|
||||||
|
# Pull the ID of the access token back out of the db
|
||||||
|
# It would really make more sense for this to be passed
|
||||||
|
# up when the access token is saved, but that's quite an
|
||||||
|
# invasive change I'd rather do separately.
|
||||||
|
user_tuple = yield self.store.get_user_by_access_token(
|
||||||
|
token
|
||||||
|
)
|
||||||
|
token_id = user_tuple["token_id"]
|
||||||
|
|
||||||
|
yield self.hs.get_pusherpool().add_pusher(
|
||||||
|
user_id=user_id,
|
||||||
|
access_token=token_id,
|
||||||
|
kind="email",
|
||||||
|
app_id="m.email",
|
||||||
|
app_display_name="Email Notifications",
|
||||||
|
device_display_name=threepid["address"],
|
||||||
|
pushkey=threepid["address"],
|
||||||
|
lang=None, # We don't know a user's language here
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
if bind_email:
|
||||||
|
logger.info("bind_email specified: binding")
|
||||||
|
logger.debug("Binding emails %s to %s" % (
|
||||||
|
threepid, user_id
|
||||||
|
))
|
||||||
|
yield self.identity_handler.bind_threepid(
|
||||||
|
threepid['threepid_creds'], user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("bind_email not specified: not binding email")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _create_registration_details(self, user_id, params):
|
||||||
|
"""Complete registration of newly-registered user
|
||||||
|
|
||||||
|
Allocates device_id if one was not given; also creates access_token
|
||||||
|
and refresh_token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(str) user_id: full canonical @user:id
|
||||||
|
(object) params: registration parameters, from which we pull
|
||||||
|
device_id and initial_device_name
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (object) dictionary for response from /register
|
||||||
|
"""
|
||||||
|
device_id = yield self._register_device(user_id, params)
|
||||||
|
|
||||||
|
access_token, refresh_token = (
|
||||||
|
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
initial_display_name=params.get("initial_device_display_name")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
|
"device_id": device_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def _register_device(self, user_id, params):
|
||||||
def onEmailTokenRequest(self, request):
|
"""Register a device for a user.
|
||||||
body = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
This is called after the user's credentials have been validated, but
|
||||||
absent = []
|
before the access token has been issued.
|
||||||
for k in required:
|
|
||||||
if k not in body:
|
|
||||||
absent.append(k)
|
|
||||||
|
|
||||||
if len(absent) > 0:
|
Args:
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
(str) user_id: full canonical @user:id
|
||||||
|
(object) params: registration parameters, from which we pull
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
device_id and initial_device_name
|
||||||
'email', body['email']
|
Returns:
|
||||||
|
defer.Deferred: (str) device_id
|
||||||
|
"""
|
||||||
|
# register the user's device
|
||||||
|
device_id = params.get("device_id")
|
||||||
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
|
device_id = self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
|
return device_id
|
||||||
if existingUid is not None:
|
|
||||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
|
||||||
|
|
||||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
|
||||||
defer.returnValue((200, ret))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self):
|
def _do_guest_registration(self):
|
||||||
|
@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
make_guest=True
|
make_guest=True
|
||||||
)
|
)
|
||||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
access_token = self.auth_handler.generate_access_token(
|
||||||
|
user_id, ["guest = true"]
|
||||||
|
)
|
||||||
|
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
||||||
|
# so long as we don't return a refresh_token here.
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
RegisterRequestTokenRestServlet(hs).register(http_server)
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
|
||||||
try:
|
try:
|
||||||
old_refresh_token = body["refresh_token"]
|
old_refresh_token = body["refresh_token"]
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
refresh_result = yield self.store.exchange_refresh_token(
|
||||||
old_refresh_token, auth_handler.generate_refresh_token)
|
old_refresh_token, auth_handler.generate_refresh_token
|
||||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
)
|
||||||
|
(user_id, new_refresh_token, device_id) = refresh_result
|
||||||
|
new_access_token = yield auth_handler.issue_access_token(
|
||||||
|
user_id, device_id
|
||||||
|
)
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"access_token": new_access_token,
|
"access_token": new_access_token,
|
||||||
"refresh_token": new_refresh_token,
|
"refresh_token": new_refresh_token,
|
||||||
|
|
|
@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
return (200, {
|
return (200, {
|
||||||
"versions": ["r0.0.1"]
|
"versions": [
|
||||||
|
"r0.0.1",
|
||||||
|
"r0.1.0",
|
||||||
|
"r0.2.0",
|
||||||
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,14 +15,12 @@
|
||||||
|
|
||||||
from synapse.http.server import respond_with_json_bytes, finish_request
|
from synapse.http.server import respond_with_json_bytes, finish_request
|
||||||
|
|
||||||
from synapse.util.stringutils import random_string
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
|
Codes, cs_error
|
||||||
)
|
)
|
||||||
|
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
from twisted.web import server, resource
|
from twisted.web import server, resource
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
|
||||||
"""
|
"""
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, directory, auth, external_addr):
|
def __init__(self, hs, directory):
|
||||||
resource.Resource.__init__(self)
|
resource.Resource.__init__(self)
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.directory = directory
|
self.directory = directory
|
||||||
self.auth = auth
|
|
||||||
self.external_addr = external_addr.rstrip('/')
|
|
||||||
self.max_upload_size = hs.config.max_upload_size
|
|
||||||
|
|
||||||
if not os.path.isdir(self.directory):
|
|
||||||
os.mkdir(self.directory)
|
|
||||||
logger.info("ContentRepoResource : Created %s directory.",
|
|
||||||
self.directory)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def map_request_to_name(self, request):
|
|
||||||
# auth the user
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
# namespace all file uploads on the user
|
|
||||||
prefix = base64.urlsafe_b64encode(
|
|
||||||
requester.user.to_string()
|
|
||||||
).replace('=', '')
|
|
||||||
|
|
||||||
# use a random string for the main portion
|
|
||||||
main_part = random_string(24)
|
|
||||||
|
|
||||||
# suffix with a file extension if we can make one. This is nice to
|
|
||||||
# provide a hint to clients on the file information. We will also reuse
|
|
||||||
# this info to spit back the content type to the client.
|
|
||||||
suffix = ""
|
|
||||||
if request.requestHeaders.hasHeader("Content-Type"):
|
|
||||||
content_type = request.requestHeaders.getRawHeaders(
|
|
||||||
"Content-Type")[0]
|
|
||||||
suffix = "." + base64.urlsafe_b64encode(content_type)
|
|
||||||
if (content_type.split("/")[0].lower() in
|
|
||||||
["image", "video", "audio"]):
|
|
||||||
file_ext = content_type.split("/")[-1]
|
|
||||||
# be a little paranoid and only allow a-z
|
|
||||||
file_ext = re.sub("[^a-z]", "", file_ext)
|
|
||||||
suffix += "." + file_ext
|
|
||||||
|
|
||||||
file_name = prefix + main_part + suffix
|
|
||||||
file_path = os.path.join(self.directory, file_name)
|
|
||||||
logger.info("User %s is uploading a file to path %s",
|
|
||||||
request.user.user_id.to_string(),
|
|
||||||
file_path)
|
|
||||||
|
|
||||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
|
||||||
attempts = 0
|
|
||||||
while os.path.exists(file_path):
|
|
||||||
main_part = random_string(24)
|
|
||||||
file_name = prefix + main_part + suffix
|
|
||||||
file_path = os.path.join(self.directory, file_name)
|
|
||||||
attempts += 1
|
|
||||||
if attempts > 25: # really? Really?
|
|
||||||
raise SynapseError(500, "Unable to create file.")
|
|
||||||
|
|
||||||
defer.returnValue(file_path)
|
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
# no auth here on purpose, to allow anyone to view, even across home
|
# no auth here on purpose, to allow anyone to view, even across home
|
||||||
|
@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
|
||||||
|
|
||||||
return server.NOT_DONE_YET
|
return server.NOT_DONE_YET
|
||||||
|
|
||||||
def render_POST(self, request):
|
|
||||||
self._async_render(request)
|
|
||||||
return server.NOT_DONE_YET
|
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
def render_OPTIONS(self, request):
|
||||||
respond_with_json_bytes(request, 200, {}, send_cors=True)
|
respond_with_json_bytes(request, 200, {}, send_cors=True)
|
||||||
return server.NOT_DONE_YET
|
return server.NOT_DONE_YET
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _async_render(self, request):
|
|
||||||
try:
|
|
||||||
# TODO: The checks here are a bit late. The content will have
|
|
||||||
# already been uploaded to a tmp file at this point
|
|
||||||
content_length = request.getHeader("Content-Length")
|
|
||||||
if content_length is None:
|
|
||||||
raise SynapseError(
|
|
||||||
msg="Request must specify a Content-Length", code=400
|
|
||||||
)
|
|
||||||
if int(content_length) > self.max_upload_size:
|
|
||||||
raise SynapseError(
|
|
||||||
msg="Upload request body is too large",
|
|
||||||
code=413,
|
|
||||||
)
|
|
||||||
|
|
||||||
fname = yield self.map_request_to_name(request)
|
|
||||||
|
|
||||||
# TODO I have a suspicious feeling this is just going to block
|
|
||||||
with open(fname, "wb") as f:
|
|
||||||
f.write(request.content.read())
|
|
||||||
|
|
||||||
# FIXME (erikj): These should use constants.
|
|
||||||
file_name = os.path.basename(fname)
|
|
||||||
# FIXME: we can't assume what the repo's public mounted path is
|
|
||||||
# ...plus self-signed SSL won't work to remote clients anyway
|
|
||||||
# ...and we can't assume that it's SSL anyway, as we might want to
|
|
||||||
# serve it via the non-SSL listener...
|
|
||||||
url = "%s/_matrix/content/%s" % (
|
|
||||||
self.external_addr, file_name
|
|
||||||
)
|
|
||||||
|
|
||||||
respond_with_json_bytes(request, 200,
|
|
||||||
json.dumps({"content_token": url}),
|
|
||||||
send_cors=True)
|
|
||||||
|
|
||||||
except CodeMessageException as e:
|
|
||||||
logger.exception(e)
|
|
||||||
respond_with_json_bytes(request, e.code,
|
|
||||||
json.dumps(cs_exception(e)))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to store file: %s" % e)
|
|
||||||
respond_with_json_bytes(
|
|
||||||
request,
|
|
||||||
500,
|
|
||||||
json.dumps({"error": "Internal server error"}),
|
|
||||||
send_cors=True)
|
|
||||||
|
|
|
@ -65,3 +65,9 @@ class MediaFilePaths(object):
|
||||||
file_id[0:2], file_id[2:4], file_id[4:],
|
file_id[0:2], file_id[2:4], file_id[4:],
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||||
|
return os.path.join(
|
||||||
|
self.base_path, "remote_thumbnail", server_name,
|
||||||
|
file_id[0:2], file_id[2:4], file_id[4:],
|
||||||
|
)
|
||||||
|
|
|
@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from twisted.internet import defer, threads
|
from twisted.internet import defer, threads
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import errno
|
||||||
|
import shutil
|
||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
import logging
|
import logging
|
||||||
|
@ -43,8 +45,11 @@ import urlparse
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
class MediaRepository(object):
|
class MediaRepository(object):
|
||||||
def __init__(self, hs, filepaths):
|
def __init__(self, hs):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.client = MatrixFederationHttpClient(hs)
|
self.client = MatrixFederationHttpClient(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -52,11 +57,28 @@ class MediaRepository(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.max_image_pixels = hs.config.max_image_pixels
|
self.max_image_pixels = hs.config.max_image_pixels
|
||||||
self.filepaths = filepaths
|
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||||
self.downloads = {}
|
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||||
|
|
||||||
|
self.remote_media_linearizer = Linearizer()
|
||||||
|
|
||||||
|
self.recently_accessed_remotes = set()
|
||||||
|
|
||||||
|
self.clock.looping_call(
|
||||||
|
self._update_recently_accessed_remotes,
|
||||||
|
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _update_recently_accessed_remotes(self):
|
||||||
|
media = self.recently_accessed_remotes
|
||||||
|
self.recently_accessed_remotes = set()
|
||||||
|
|
||||||
|
yield self.store.update_cached_last_access_time(
|
||||||
|
media, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _makedirs(filepath):
|
def _makedirs(filepath):
|
||||||
dirname = os.path.dirname(filepath)
|
dirname = os.path.dirname(filepath)
|
||||||
|
@ -93,22 +115,12 @@ class MediaRepository(object):
|
||||||
|
|
||||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_remote_media(self, server_name, media_id):
|
def get_remote_media(self, server_name, media_id):
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
download = self.downloads.get(key)
|
with (yield self.remote_media_linearizer.queue(key)):
|
||||||
if download is None:
|
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
||||||
download = self._get_remote_media_impl(server_name, media_id)
|
defer.returnValue(media_info)
|
||||||
download = ObservableDeferred(
|
|
||||||
download,
|
|
||||||
consumeErrors=True
|
|
||||||
)
|
|
||||||
self.downloads[key] = download
|
|
||||||
|
|
||||||
@download.addBoth
|
|
||||||
def callback(media_info):
|
|
||||||
del self.downloads[key]
|
|
||||||
return media_info
|
|
||||||
return download.observe()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_remote_media_impl(self, server_name, media_id):
|
def _get_remote_media_impl(self, server_name, media_id):
|
||||||
|
@ -119,6 +131,11 @@ class MediaRepository(object):
|
||||||
media_info = yield self._download_remote_file(
|
media_info = yield self._download_remote_file(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.recently_accessed_remotes.add((server_name, media_id))
|
||||||
|
yield self.store.update_cached_last_access_time(
|
||||||
|
[(server_name, media_id)], self.clock.time_msec()
|
||||||
|
)
|
||||||
defer.returnValue(media_info)
|
defer.returnValue(media_info)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -416,6 +433,41 @@ class MediaRepository(object):
|
||||||
"height": m_height,
|
"height": m_height,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_old_remote_media(self, before_ts):
|
||||||
|
old_media = yield self.store.get_remote_media_before(before_ts)
|
||||||
|
|
||||||
|
deleted = 0
|
||||||
|
|
||||||
|
for media in old_media:
|
||||||
|
origin = media["media_origin"]
|
||||||
|
media_id = media["media_id"]
|
||||||
|
file_id = media["filesystem_id"]
|
||||||
|
key = (origin, media_id)
|
||||||
|
|
||||||
|
logger.info("Deleting: %r", key)
|
||||||
|
|
||||||
|
with (yield self.remote_media_linearizer.queue(key)):
|
||||||
|
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||||
|
try:
|
||||||
|
os.remove(full_path)
|
||||||
|
except OSError as e:
|
||||||
|
logger.warn("Failed to remove file: %r", full_path)
|
||||||
|
if e.errno == errno.ENOENT:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
|
||||||
|
origin, file_id
|
||||||
|
)
|
||||||
|
shutil.rmtree(thumbnail_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
yield self.store.delete_remote_media(origin, media_id)
|
||||||
|
deleted += 1
|
||||||
|
|
||||||
|
defer.returnValue({"deleted": deleted})
|
||||||
|
|
||||||
|
|
||||||
class MediaRepositoryResource(Resource):
|
class MediaRepositoryResource(Resource):
|
||||||
"""File uploading and downloading.
|
"""File uploading and downloading.
|
||||||
|
@ -464,9 +516,8 @@ class MediaRepositoryResource(Resource):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
filepaths = MediaFilePaths(hs.config.media_store_path)
|
|
||||||
|
|
||||||
media_repo = MediaRepository(hs, filepaths)
|
media_repo = hs.get_media_repository()
|
||||||
|
|
||||||
self.putChild("upload", UploadResource(hs, media_repo))
|
self.putChild("upload", UploadResource(hs, media_repo))
|
||||||
self.putChild("download", DownloadResource(hs, media_repo))
|
self.putChild("download", DownloadResource(hs, media_repo))
|
||||||
|
|
|
@ -29,6 +29,8 @@ from synapse.http.server import (
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
@ -329,20 +331,24 @@ class PreviewUrlResource(Resource):
|
||||||
# ...or if they are within a <script/> or <style/> tag.
|
# ...or if they are within a <script/> or <style/> tag.
|
||||||
# This is a very very very coarse approximation to a plain text
|
# This is a very very very coarse approximation to a plain text
|
||||||
# render of the page.
|
# render of the page.
|
||||||
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
|
|
||||||
"ancestor::aside | ancestor::footer | "
|
# We don't just use XPATH here as that is slow on some machines.
|
||||||
"ancestor::script | ancestor::style)]" +
|
|
||||||
"[ancestor::body]")
|
# We clone `tree` as we modify it.
|
||||||
text = ''
|
cloned_tree = deepcopy(tree.find("body"))
|
||||||
for text_node in text_nodes:
|
|
||||||
if len(text) < 500:
|
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
||||||
text += text_node + ' '
|
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
||||||
else:
|
el.getparent().remove(el)
|
||||||
break
|
|
||||||
text = re.sub(r'[\t ]+', ' ', text)
|
# Split all the text nodes into paragraphs (by splitting on new
|
||||||
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
|
# lines)
|
||||||
text = text.strip()[:500]
|
text_nodes = (
|
||||||
og['og:description'] = text if text else None
|
re.sub(r'\s+', '\n', el.text).strip()
|
||||||
|
for el in cloned_tree.iter()
|
||||||
|
if el.text and isinstance(el.tag, basestring) # Removes comments
|
||||||
|
)
|
||||||
|
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||||
|
|
||||||
# TODO: delete the url downloads to stop diskfilling,
|
# TODO: delete the url downloads to stop diskfilling,
|
||||||
# as we only ever cared about its OG
|
# as we only ever cared about its OG
|
||||||
|
@ -450,3 +456,56 @@ class PreviewUrlResource(Resource):
|
||||||
content_type.startswith("application/xhtml")
|
content_type.startswith("application/xhtml")
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||||
|
# Try to get a summary of between 200 and 500 words, respecting
|
||||||
|
# first paragraph and then word boundaries.
|
||||||
|
# TODO: Respect sentences?
|
||||||
|
|
||||||
|
description = ''
|
||||||
|
|
||||||
|
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||||
|
for text_node in text_nodes:
|
||||||
|
if len(description) < min_size:
|
||||||
|
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
|
||||||
|
description += text_node + '\n\n'
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
description = description.strip()
|
||||||
|
description = re.sub(r'[\t ]+', ' ', description)
|
||||||
|
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
|
||||||
|
|
||||||
|
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||||
|
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||||
|
if len(description) > max_size:
|
||||||
|
new_desc = ""
|
||||||
|
|
||||||
|
# This splits the paragraph into words, but keeping the
|
||||||
|
# (preceeding) whitespace intact so we can easily concat
|
||||||
|
# words back together.
|
||||||
|
for match in re.finditer("\s*\S+", description):
|
||||||
|
word = match.group()
|
||||||
|
|
||||||
|
# Keep adding words while the total length is less than
|
||||||
|
# MAX_SIZE.
|
||||||
|
if len(word) + len(new_desc) < max_size:
|
||||||
|
new_desc += word
|
||||||
|
else:
|
||||||
|
# At this point the next word *will* take us over
|
||||||
|
# MAX_SIZE, but we also want to ensure that its not
|
||||||
|
# a huge word. If it is add it anyway and we'll
|
||||||
|
# truncate later.
|
||||||
|
if len(new_desc) < min_size:
|
||||||
|
new_desc += word
|
||||||
|
break
|
||||||
|
|
||||||
|
# Double check that we're not over the limit
|
||||||
|
if len(new_desc) > max_size:
|
||||||
|
new_desc = new_desc[:max_size]
|
||||||
|
|
||||||
|
# We always add an ellipsis because at the very least
|
||||||
|
# we chopped mid paragraph.
|
||||||
|
description = new_desc.strip() + "…"
|
||||||
|
return description if description else None
|
||||||
|
|
|
@ -19,37 +19,38 @@
|
||||||
# partial one for unit test mocking.
|
# partial one for unit test mocking.
|
||||||
|
|
||||||
# Imports required for the default HomeServer() implementation
|
# Imports required for the default HomeServer() implementation
|
||||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
|
||||||
from twisted.enterprise import adbapi
|
|
||||||
|
|
||||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
|
||||||
from synapse.appservice.api import ApplicationServiceApi
|
|
||||||
from synapse.federation import initialize_http_replication
|
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
|
||||||
from synapse.notifier import Notifier
|
|
||||||
from synapse.api.auth import Auth
|
|
||||||
from synapse.handlers import Handlers
|
|
||||||
from synapse.handlers.presence import PresenceHandler
|
|
||||||
from synapse.handlers.sync import SyncHandler
|
|
||||||
from synapse.handlers.typing import TypingHandler
|
|
||||||
from synapse.handlers.room import RoomListHandler
|
|
||||||
from synapse.handlers.auth import AuthHandler
|
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
|
||||||
from synapse.state import StateHandler
|
|
||||||
from synapse.storage import DataStore
|
|
||||||
from synapse.util import Clock
|
|
||||||
from synapse.util.distributor import Distributor
|
|
||||||
from synapse.streams.events import EventSources
|
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
|
||||||
from synapse.crypto.keyring import Keyring
|
|
||||||
from synapse.push.pusherpool import PusherPool
|
|
||||||
from synapse.events.builder import EventBuilderFactory
|
|
||||||
from synapse.api.filtering import Filtering
|
|
||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.enterprise import adbapi
|
||||||
|
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||||
|
|
||||||
|
from synapse.api.auth import Auth
|
||||||
|
from synapse.api.filtering import Filtering
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
from synapse.appservice.api import ApplicationServiceApi
|
||||||
|
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||||
|
from synapse.crypto.keyring import Keyring
|
||||||
|
from synapse.events.builder import EventBuilderFactory
|
||||||
|
from synapse.federation import initialize_http_replication
|
||||||
|
from synapse.handlers import Handlers
|
||||||
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
|
from synapse.handlers.auth import AuthHandler
|
||||||
|
from synapse.handlers.device import DeviceHandler
|
||||||
|
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||||
|
from synapse.handlers.presence import PresenceHandler
|
||||||
|
from synapse.handlers.room import RoomListHandler
|
||||||
|
from synapse.handlers.sync import SyncHandler
|
||||||
|
from synapse.handlers.typing import TypingHandler
|
||||||
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
|
from synapse.notifier import Notifier
|
||||||
|
from synapse.push.pusherpool import PusherPool
|
||||||
|
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||||
|
from synapse.state import StateHandler
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.streams.events import EventSources
|
||||||
|
from synapse.util import Clock
|
||||||
|
from synapse.util.distributor import Distributor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -91,6 +92,8 @@ class HomeServer(object):
|
||||||
'typing_handler',
|
'typing_handler',
|
||||||
'room_list_handler',
|
'room_list_handler',
|
||||||
'auth_handler',
|
'auth_handler',
|
||||||
|
'device_handler',
|
||||||
|
'e2e_keys_handler',
|
||||||
'application_service_api',
|
'application_service_api',
|
||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
|
@ -113,6 +116,7 @@ class HomeServer(object):
|
||||||
'filtering',
|
'filtering',
|
||||||
'http_client_context_factory',
|
'http_client_context_factory',
|
||||||
'simple_http_client',
|
'simple_http_client',
|
||||||
|
'media_repository',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
|
@ -195,6 +199,12 @@ class HomeServer(object):
|
||||||
def build_auth_handler(self):
|
def build_auth_handler(self):
|
||||||
return AuthHandler(self)
|
return AuthHandler(self)
|
||||||
|
|
||||||
|
def build_device_handler(self):
|
||||||
|
return DeviceHandler(self)
|
||||||
|
|
||||||
|
def build_e2e_keys_handler(self):
|
||||||
|
return E2eKeysHandler(self)
|
||||||
|
|
||||||
def build_application_service_api(self):
|
def build_application_service_api(self):
|
||||||
return ApplicationServiceApi(self)
|
return ApplicationServiceApi(self)
|
||||||
|
|
||||||
|
@ -233,6 +243,9 @@ class HomeServer(object):
|
||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def build_media_repository(self):
|
||||||
|
return MediaRepository(self)
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
|
25
synapse/server.pyi
Normal file
25
synapse/server.pyi
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
import synapse.handlers
|
||||||
|
import synapse.handlers.auth
|
||||||
|
import synapse.handlers.device
|
||||||
|
import synapse.handlers.e2e_keys
|
||||||
|
import synapse.storage
|
||||||
|
import synapse.state
|
||||||
|
|
||||||
|
class HomeServer(object):
|
||||||
|
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_datastore(self) -> synapse.storage.DataStore:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_handlers(self) -> synapse.handlers.Handlers:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||||
|
pass
|
|
@ -379,7 +379,8 @@ class StateHandler(object):
|
||||||
try:
|
try:
|
||||||
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
||||||
# get around circular deps.
|
# get around circular deps.
|
||||||
self.hs.get_auth().check(event, auth_events)
|
# The signatures have already been checked at this point
|
||||||
|
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
|
||||||
prev_event = event
|
prev_event = event
|
||||||
except AuthError:
|
except AuthError:
|
||||||
return prev_event
|
return prev_event
|
||||||
|
@ -391,7 +392,8 @@ class StateHandler(object):
|
||||||
try:
|
try:
|
||||||
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
||||||
# get around circular deps.
|
# get around circular deps.
|
||||||
self.hs.get_auth().check(event, auth_events)
|
# The signatures have already been checked at this point
|
||||||
|
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
|
||||||
return event
|
return event
|
||||||
except AuthError:
|
except AuthError:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.storage.devices import DeviceStore
|
||||||
from .appservice import (
|
from .appservice import (
|
||||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||||
)
|
)
|
||||||
|
@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
EventPushActionsStore,
|
EventPushActionsStore,
|
||||||
OpenIdStore,
|
OpenIdStore,
|
||||||
ClientIpStore,
|
ClientIpStore,
|
||||||
|
DeviceStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
|
@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
extra_tables=[("local_invites", "stream_id")]
|
extra_tables=[("local_invites", "stream_id")]
|
||||||
)
|
)
|
||||||
self._backfill_id_gen = StreamIdGenerator(
|
self._backfill_id_gen = StreamIdGenerator(
|
||||||
db_conn, "events", "stream_ordering", step=-1
|
db_conn, "events", "stream_ordering", step=-1,
|
||||||
|
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||||
)
|
)
|
||||||
self._receipts_id_gen = StreamIdGenerator(
|
self._receipts_id_gen = StreamIdGenerator(
|
||||||
db_conn, "receipts_linearized", "stream_id"
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
|
|
@ -597,10 +597,13 @@ class SQLBaseStore(object):
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table : string giving the table name
|
table (str): the table name
|
||||||
keyvalues : dict of column names and values to select the rows with,
|
keyvalues (dict[str, Any] | None):
|
||||||
or None to not apply a WHERE clause.
|
column names and values to select the rows with, or None to not
|
||||||
retcols : list of strings giving the names of the columns to return
|
apply a WHERE clause.
|
||||||
|
retcols (iterable[str]): the names of the columns to return
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: resolves to list[dict[str, Any]]
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
|
@ -615,9 +618,11 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn : Transaction object
|
txn : Transaction object
|
||||||
table : string giving the table name
|
table (str): the table name
|
||||||
keyvalues : dict of column names and values to select the rows with
|
keyvalues (dict[str, T] | None):
|
||||||
retcols : list of strings giving the names of the columns to return
|
column names and values to select the rows with, or None to not
|
||||||
|
apply a WHERE clause.
|
||||||
|
retcols (iterable[str]): the names of the columns to return
|
||||||
"""
|
"""
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
|
@ -807,6 +812,11 @@ class SQLBaseStore(object):
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "more than one row matched")
|
raise StoreError(500, "more than one row matched")
|
||||||
|
|
||||||
|
def _simple_delete(self, table, keyvalues, desc):
|
||||||
|
return self.runInteraction(
|
||||||
|
desc, self._simple_delete_txn, table, keyvalues
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_delete_txn(txn, table, keyvalues):
|
def _simple_delete_txn(txn, table, keyvalues):
|
||||||
sql = "DELETE FROM %s WHERE %s" % (
|
sql = "DELETE FROM %s WHERE %s" % (
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from . import engines
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start_doing_background_updates(self):
|
def start_doing_background_updates(self):
|
||||||
while True:
|
assert self._background_update_timer is None, \
|
||||||
if self._background_update_timer is not None:
|
"background updates already running"
|
||||||
return
|
|
||||||
|
|
||||||
|
logger.info("Starting background schema updates")
|
||||||
|
|
||||||
|
while True:
|
||||||
sleep = defer.Deferred()
|
sleep = defer.Deferred()
|
||||||
self._background_update_timer = self._clock.call_later(
|
self._background_update_timer = self._clock.call_later(
|
||||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
||||||
|
@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_timer = None
|
self._background_update_timer = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = yield self.do_background_update(
|
result = yield self.do_next_background_update(
|
||||||
self.BACKGROUND_UPDATE_DURATION_MS
|
self.BACKGROUND_UPDATE_DURATION_MS
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception("Error doing update")
|
logger.exception("Error doing update")
|
||||||
|
else:
|
||||||
if result is None:
|
if result is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"No more background updates to do."
|
"No more background updates to do."
|
||||||
" Unscheduling background update task."
|
" Unscheduling background update task."
|
||||||
)
|
)
|
||||||
return
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_background_update(self, desired_duration_ms):
|
def do_next_background_update(self, desired_duration_ms):
|
||||||
"""Does some amount of work on a background update
|
"""Does some amount of work on the next queued background update
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
desired_duration_ms(float): How long we want to spend
|
desired_duration_ms(float): How long we want to spend
|
||||||
updating.
|
updating.
|
||||||
|
@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_queue.append(update['update_name'])
|
self._background_update_queue.append(update['update_name'])
|
||||||
|
|
||||||
if not self._background_update_queue:
|
if not self._background_update_queue:
|
||||||
|
# no work left to do
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
# pop from the front, and add back to the back
|
||||||
update_name = self._background_update_queue.pop(0)
|
update_name = self._background_update_queue.pop(0)
|
||||||
self._background_update_queue.append(update_name)
|
self._background_update_queue.append(update_name)
|
||||||
|
|
||||||
|
res = yield self._do_background_update(update_name, desired_duration_ms)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_background_update(self, update_name, desired_duration_ms):
|
||||||
|
logger.info("Starting update batch on background update '%s'",
|
||||||
|
update_name)
|
||||||
|
|
||||||
update_handler = self._background_update_handlers[update_name]
|
update_handler = self._background_update_handlers[update_name]
|
||||||
|
|
||||||
performance = self._background_update_performance.get(update_name)
|
performance = self._background_update_performance.get(update_name)
|
||||||
|
@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
self._background_update_handlers[update_name] = update_handler
|
self._background_update_handlers[update_name] = update_handler
|
||||||
|
|
||||||
|
def register_background_index_update(self, update_name, index_name,
|
||||||
|
table, columns):
|
||||||
|
"""Helper for store classes to do a background index addition
|
||||||
|
|
||||||
|
To use:
|
||||||
|
|
||||||
|
1. use a schema delta file to add a background update. Example:
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('my_new_index', '{}');
|
||||||
|
|
||||||
|
2. In the Store constructor, call this method
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_name (str): update_name to register for
|
||||||
|
index_name (str): name of index to add
|
||||||
|
table (str): table to add index to
|
||||||
|
columns (list[str]): columns/expressions to include in index
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if this is postgres, we add the indexes concurrently. Otherwise
|
||||||
|
# we fall back to doing it inline
|
||||||
|
if isinstance(self.database_engine, engines.PostgresEngine):
|
||||||
|
conc = True
|
||||||
|
else:
|
||||||
|
conc = False
|
||||||
|
|
||||||
|
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
|
||||||
|
% {
|
||||||
|
"conc": "CONCURRENTLY" if conc else "",
|
||||||
|
"name": index_name,
|
||||||
|
"table": table,
|
||||||
|
"columns": ", ".join(columns),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_index_concurrently(conn):
|
||||||
|
conn.rollback()
|
||||||
|
# postgres insists on autocommit for the index
|
||||||
|
conn.set_session(autocommit=True)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(sql)
|
||||||
|
conn.set_session(autocommit=False)
|
||||||
|
|
||||||
|
def create_index(conn):
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(sql)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def updater(progress, batch_size):
|
||||||
|
logger.info("Adding index %s to %s", index_name, table)
|
||||||
|
if conc:
|
||||||
|
yield self.runWithConnection(create_index_concurrently)
|
||||||
|
else:
|
||||||
|
yield self.runWithConnection(create_index)
|
||||||
|
yield self._end_background_update(update_name)
|
||||||
|
defer.returnValue(1)
|
||||||
|
|
||||||
|
self.register_background_update_handler(update_name, updater)
|
||||||
|
|
||||||
def start_background_update(self, update_name, progress):
|
def start_background_update(self, update_name, progress):
|
||||||
"""Starts a background update running.
|
"""Starts a background update running.
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, Cache
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._base import Cache
|
||||||
|
from . import background_updates
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||||
# times give more inserts into the database even for readonly API hits
|
# times give more inserts into the database even for readonly API hits
|
||||||
|
@ -24,8 +28,7 @@ from twisted.internet import defer
|
||||||
LAST_SEEN_GRANULARITY = 120 * 1000
|
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||||
|
|
||||||
|
|
||||||
class ClientIpStore(SQLBaseStore):
|
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
|
@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
|
||||||
|
|
||||||
super(ClientIpStore, self).__init__(hs)
|
super(ClientIpStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"user_ips_device_index",
|
||||||
|
index_name="user_ips_device_id",
|
||||||
|
table="user_ips",
|
||||||
|
columns=["user_id", "device_id", "last_seen"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
key = (user.to_string(), access_token, ip)
|
key = (user.to_string(), access_token, ip)
|
||||||
|
|
||||||
|
@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
"user_agent": user_agent,
|
"user_agent": user_agent,
|
||||||
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
values={
|
values={
|
||||||
"last_seen": now,
|
"last_seen": now,
|
||||||
|
@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
|
||||||
desc="insert_client_ip",
|
desc="insert_client_ip",
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_last_client_ip_by_device(self, devices):
|
||||||
|
"""For each device_id listed, give the user_ip it was last seen on
|
||||||
|
|
||||||
|
Args:
|
||||||
|
devices (iterable[(str, str)]): list of (user_id, device_id) pairs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: resolves to a dict, where the keys
|
||||||
|
are (user_id, device_id) tuples. The values are also dicts, with
|
||||||
|
keys giving the column names
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = yield self.runInteraction(
|
||||||
|
"get_last_client_ip_by_device",
|
||||||
|
self._get_last_client_ip_by_device_txn,
|
||||||
|
retcols=(
|
||||||
|
"user_id",
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
),
|
||||||
|
devices=devices
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = {(d["user_id"], d["device_id"]): d for d in res}
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
|
||||||
|
where_clauses = []
|
||||||
|
bindings = []
|
||||||
|
for (user_id, device_id) in devices:
|
||||||
|
if device_id is None:
|
||||||
|
where_clauses.append("(user_id = ? AND device_id IS NULL)")
|
||||||
|
bindings.extend((user_id, ))
|
||||||
|
else:
|
||||||
|
where_clauses.append("(user_id = ? AND device_id = ?)")
|
||||||
|
bindings.extend((user_id, device_id))
|
||||||
|
|
||||||
|
inner_select = (
|
||||||
|
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
|
||||||
|
"WHERE %(where)s "
|
||||||
|
"GROUP BY user_id, device_id"
|
||||||
|
) % {
|
||||||
|
"where": " OR ".join(where_clauses),
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT %(retcols)s FROM user_ips "
|
||||||
|
"JOIN (%(inner_select)s) ips ON"
|
||||||
|
" user_ips.last_seen = ips.mls AND"
|
||||||
|
" user_ips.user_id = ips.user_id AND"
|
||||||
|
" (user_ips.device_id = ips.device_id OR"
|
||||||
|
" (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
|
||||||
|
" )"
|
||||||
|
) % {
|
||||||
|
"retcols": ",".join("user_ips." + c for c in retcols),
|
||||||
|
"inner_select": inner_select,
|
||||||
|
}
|
||||||
|
|
||||||
|
txn.execute(sql, bindings)
|
||||||
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
137
synapse/storage/devices.py
Normal file
137
synapse/storage/devices.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import StoreError
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStore(SQLBaseStore):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def store_device(self, user_id, device_id,
|
||||||
|
initial_device_display_name,
|
||||||
|
ignore_if_known=True):
|
||||||
|
"""Ensure the given device is known; add it to the store if not
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): id of user associated with the device
|
||||||
|
device_id (str): id of device
|
||||||
|
initial_device_display_name (str): initial displayname of the
|
||||||
|
device
|
||||||
|
ignore_if_known (bool): ignore integrity errors which mean the
|
||||||
|
device is already known
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
Raises:
|
||||||
|
StoreError: if ignore_if_known is False and the device was already
|
||||||
|
known
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
yield self._simple_insert(
|
||||||
|
"devices",
|
||||||
|
values={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"display_name": initial_device_display_name
|
||||||
|
},
|
||||||
|
desc="store_device",
|
||||||
|
or_ignore=ignore_if_known,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("store_device with device_id=%s failed: %s",
|
||||||
|
device_id, e)
|
||||||
|
raise StoreError(500, "Problem storing device.")
|
||||||
|
|
||||||
|
def get_device(self, user_id, device_id):
|
||||||
|
"""Retrieve a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to retrieve
|
||||||
|
Returns:
|
||||||
|
defer.Deferred for a dict containing the device information
|
||||||
|
Raises:
|
||||||
|
StoreError: if the device is not found
|
||||||
|
"""
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
retcols=("user_id", "device_id", "display_name"),
|
||||||
|
desc="get_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
"""Delete a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to delete
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
return self._simple_delete_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_device(self, user_id, device_id, new_display_name=None):
|
||||||
|
"""Update a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to update
|
||||||
|
new_display_name (str|None): new displayname for device; None
|
||||||
|
to leave unchanged
|
||||||
|
Raises:
|
||||||
|
StoreError: if the device is not found
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
updates = {}
|
||||||
|
if new_display_name is not None:
|
||||||
|
updates["display_name"] = new_display_name
|
||||||
|
if not updates:
|
||||||
|
return defer.succeed(None)
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
updatevalues=updates,
|
||||||
|
desc="update_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_devices_by_user(self, user_id):
|
||||||
|
"""Retrieve all of a user's registered devices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: resolves to a dict from device_id to a dict
|
||||||
|
containing "device_id", "user_id" and "display_name" for each
|
||||||
|
device.
|
||||||
|
"""
|
||||||
|
devices = yield self._simple_select_list(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("user_id", "device_id", "display_name"),
|
||||||
|
desc="get_devices_by_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({d["device_id"]: d for d in devices})
|
|
@ -12,6 +12,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import twisted.internet.defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
query_list(list): List of pairs of user_ids and device_ids.
|
query_list(list): List of pairs of user_ids and device_ids.
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping from user-id to dict mapping from device_id to
|
Dict mapping from user-id to dict mapping from device_id to
|
||||||
key json byte strings.
|
dict containing "key_json", "device_display_name".
|
||||||
"""
|
"""
|
||||||
def _get_e2e_device_keys(txn):
|
if not query_list:
|
||||||
result = {}
|
return {}
|
||||||
for user_id, device_id in query_list:
|
|
||||||
user_result = result.setdefault(user_id, {})
|
return self.runInteraction(
|
||||||
keyvalues = {"user_id": user_id}
|
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
|
||||||
if device_id:
|
|
||||||
keyvalues["device_id"] = device_id
|
|
||||||
rows = self._simple_select_list_txn(
|
|
||||||
txn, table="e2e_device_keys_json",
|
|
||||||
keyvalues=keyvalues,
|
|
||||||
retcols=["device_id", "key_json"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_e2e_device_keys_txn(self, txn, query_list):
|
||||||
|
query_clauses = []
|
||||||
|
query_params = []
|
||||||
|
|
||||||
|
for (user_id, device_id) in query_list:
|
||||||
|
query_clause = "k.user_id = ?"
|
||||||
|
query_params.append(user_id)
|
||||||
|
|
||||||
|
if device_id:
|
||||||
|
query_clause += " AND k.device_id = ?"
|
||||||
|
query_params.append(device_id)
|
||||||
|
|
||||||
|
query_clauses.append(query_clause)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT k.user_id, k.device_id, "
|
||||||
|
" d.display_name AS device_display_name, "
|
||||||
|
" k.key_json"
|
||||||
|
" FROM e2e_device_keys_json k"
|
||||||
|
" LEFT JOIN devices d ON d.user_id = k.user_id"
|
||||||
|
" AND d.device_id = k.device_id"
|
||||||
|
" WHERE %s"
|
||||||
|
) % (
|
||||||
|
" OR ".join("(" + q + ")" for q in query_clauses)
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, query_params)
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
result = collections.defaultdict(dict)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
user_result[row["device_id"]] = row["key_json"]
|
result[row["user_id"]][row["device_id"]] = row
|
||||||
|
|
||||||
return result
|
return result
|
||||||
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
|
|
||||||
|
|
||||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||||
def _add_e2e_one_time_keys(txn):
|
def _add_e2e_one_time_keys(txn):
|
||||||
|
@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@twisted.internet.defer.inlineCallbacks
|
||||||
|
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_e2e_device_keys_by_device"
|
||||||
|
)
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_e2e_one_time_keys_by_device"
|
||||||
|
)
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
from synapse.types import RoomStreamToken
|
||||||
|
from .stream import lower_bound
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
|
|
||||||
stream_ordering = results[0][0]
|
stream_ordering = results[0][0]
|
||||||
topological_ordering = results[0][1]
|
topological_ordering = results[0][1]
|
||||||
|
token = RoomStreamToken(
|
||||||
|
topological_ordering, stream_ordering
|
||||||
|
)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT sum(notif), sum(highlight)"
|
"SELECT sum(notif), sum(highlight)"
|
||||||
|
@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
" WHERE"
|
" WHERE"
|
||||||
" user_id = ?"
|
" user_id = ?"
|
||||||
" AND room_id = ?"
|
" AND room_id = ?"
|
||||||
" AND ("
|
" AND %s"
|
||||||
" topological_ordering > ?"
|
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||||
" OR (topological_ordering = ? AND stream_ordering > ?)"
|
|
||||||
")"
|
txn.execute(sql, (user_id, room_id))
|
||||||
)
|
|
||||||
txn.execute(sql, (
|
|
||||||
user_id, room_id,
|
|
||||||
topological_ordering, topological_ordering, stream_ordering
|
|
||||||
))
|
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return {
|
return {
|
||||||
|
@ -117,21 +117,149 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_unread_push_actions_for_user_in_range(self, user_id,
|
def get_unread_push_actions_for_user_in_range_for_http(
|
||||||
min_stream_ordering,
|
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||||
max_stream_ordering=None,
|
):
|
||||||
limit=20):
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
|
within the given stream ordering range. Called by the httppusher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user to fetch push actions for.
|
||||||
|
min_stream_ordering(int): The exclusive lower bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
max_stream_ordering(int): The inclusive upper bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
limit (int): The maximum number of rows to return.
|
||||||
|
Returns:
|
||||||
|
A promise which resolves to a list of dicts with the keys "event_id",
|
||||||
|
"room_id", "stream_ordering", "actions".
|
||||||
|
The list will be ordered by ascending stream_ordering.
|
||||||
|
The list will have between 0~limit entries.
|
||||||
|
"""
|
||||||
|
# find rooms that have a read receipt in them and return the next
|
||||||
|
# push actions
|
||||||
|
def get_after_receipt(txn):
|
||||||
|
# find rooms that have a read receipt in them and return the next
|
||||||
|
# push actions
|
||||||
|
sql = (
|
||||||
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
|
||||||
|
" FROM ("
|
||||||
|
" SELECT room_id,"
|
||||||
|
" MAX(topological_ordering) as topological_ordering,"
|
||||||
|
" MAX(stream_ordering) as stream_ordering"
|
||||||
|
" FROM events"
|
||||||
|
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||||
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
|
") AS rl,"
|
||||||
|
" event_push_actions AS ep"
|
||||||
|
" WHERE"
|
||||||
|
" ep.room_id = rl.room_id"
|
||||||
|
" AND ("
|
||||||
|
" ep.topological_ordering > rl.topological_ordering"
|
||||||
|
" OR ("
|
||||||
|
" ep.topological_ordering = rl.topological_ordering"
|
||||||
|
" AND ep.stream_ordering > rl.stream_ordering"
|
||||||
|
" )"
|
||||||
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
args = [
|
||||||
|
user_id, user_id,
|
||||||
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
|
]
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return txn.fetchall()
|
||||||
|
after_read_receipt = yield self.runInteraction(
|
||||||
|
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||||
|
)
|
||||||
|
|
||||||
|
# There are rooms with push actions in them but you don't have a read receipt in
|
||||||
|
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||||
|
# not have read receipts in them too.
|
||||||
|
def get_no_receipt(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
|
" e.received_ts"
|
||||||
|
" FROM event_push_actions AS ep"
|
||||||
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
|
" WHERE"
|
||||||
|
" ep.room_id NOT IN ("
|
||||||
|
" SELECT room_id FROM receipts_linearized"
|
||||||
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
args = [
|
||||||
|
user_id, user_id,
|
||||||
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
|
]
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return txn.fetchall()
|
||||||
|
no_read_receipt = yield self.runInteraction(
|
||||||
|
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||||
|
)
|
||||||
|
|
||||||
|
notifs = [
|
||||||
|
{
|
||||||
|
"event_id": row[0],
|
||||||
|
"room_id": row[1],
|
||||||
|
"stream_ordering": row[2],
|
||||||
|
"actions": json.loads(row[3]),
|
||||||
|
} for row in after_read_receipt + no_read_receipt
|
||||||
|
]
|
||||||
|
|
||||||
|
# Now sort it so it's ordered correctly, since currently it will
|
||||||
|
# contain results from the first query, correctly ordered, followed
|
||||||
|
# by results from the second query, but we want them all ordered
|
||||||
|
# by stream_ordering, oldest first.
|
||||||
|
notifs.sort(key=lambda r: r['stream_ordering'])
|
||||||
|
|
||||||
|
# Take only up to the limit. We have to stop at the limit because
|
||||||
|
# one of the subqueries may have hit the limit.
|
||||||
|
defer.returnValue(notifs[:limit])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_unread_push_actions_for_user_in_range_for_email(
|
||||||
|
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||||
|
):
|
||||||
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
|
within the given stream ordering range. Called by the emailpusher
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user to fetch push actions for.
|
||||||
|
min_stream_ordering(int): The exclusive lower bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
max_stream_ordering(int): The inclusive upper bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
limit (int): The maximum number of rows to return.
|
||||||
|
Returns:
|
||||||
|
A promise which resolves to a list of dicts with the keys "event_id",
|
||||||
|
"room_id", "stream_ordering", "actions", "received_ts".
|
||||||
|
The list will be ordered by descending received_ts.
|
||||||
|
The list will have between 0~limit entries.
|
||||||
|
"""
|
||||||
|
# find rooms that have a read receipt in them and return the most recent
|
||||||
|
# push actions
|
||||||
def get_after_receipt(txn):
|
def get_after_receipt(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
"e.received_ts "
|
" e.received_ts"
|
||||||
"FROM ("
|
" FROM ("
|
||||||
" SELECT room_id, user_id, "
|
" SELECT room_id,"
|
||||||
" max(topological_ordering) as topological_ordering, "
|
" MAX(topological_ordering) as topological_ordering,"
|
||||||
" max(stream_ordering) as stream_ordering "
|
" MAX(stream_ordering) as stream_ordering"
|
||||||
" FROM events"
|
" FROM events"
|
||||||
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
|
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||||
" GROUP BY room_id, user_id"
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
") AS rl,"
|
") AS rl,"
|
||||||
" event_push_actions AS ep"
|
" event_push_actions AS ep"
|
||||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
|
@ -144,46 +272,53 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
" AND ep.stream_ordering > rl.stream_ordering"
|
" AND ep.stream_ordering > rl.stream_ordering"
|
||||||
" )"
|
" )"
|
||||||
" )"
|
" )"
|
||||||
" AND ep.stream_ordering > ?"
|
|
||||||
" AND ep.user_id = ?"
|
" AND ep.user_id = ?"
|
||||||
" AND ep.user_id = rl.user_id"
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||||
)
|
)
|
||||||
args = [min_stream_ordering, user_id]
|
args = [
|
||||||
if max_stream_ordering is not None:
|
user_id, user_id,
|
||||||
sql += " AND ep.stream_ordering <= ?"
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
args.append(max_stream_ordering)
|
]
|
||||||
sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
|
|
||||||
args.append(limit)
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
after_read_receipt = yield self.runInteraction(
|
after_read_receipt = yield self.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range", get_after_receipt
|
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# There are rooms with push actions in them but you don't have a read receipt in
|
||||||
|
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||||
|
# not have read receipts in them too.
|
||||||
def get_no_receipt(txn):
|
def get_no_receipt(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
" e.received_ts"
|
" e.received_ts"
|
||||||
" FROM event_push_actions AS ep"
|
" FROM event_push_actions AS ep"
|
||||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
" WHERE ep.room_id not in ("
|
" WHERE"
|
||||||
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
|
" ep.room_id NOT IN ("
|
||||||
|
" SELECT room_id FROM receipts_linearized"
|
||||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
" GROUP BY room_id"
|
" GROUP BY room_id"
|
||||||
") AND ep.user_id = ? AND ep.stream_ordering > ?"
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||||
)
|
)
|
||||||
args = [user_id, user_id, min_stream_ordering]
|
args = [
|
||||||
if max_stream_ordering is not None:
|
user_id, user_id,
|
||||||
sql += " AND ep.stream_ordering <= ?"
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
args.append(max_stream_ordering)
|
]
|
||||||
sql += " ORDER BY ep.stream_ordering ASC"
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
no_read_receipt = yield self.runInteraction(
|
no_read_receipt = yield self.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range", get_no_receipt
|
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue([
|
# Make a list of dicts from the two sets of results.
|
||||||
|
notifs = [
|
||||||
{
|
{
|
||||||
"event_id": row[0],
|
"event_id": row[0],
|
||||||
"room_id": row[1],
|
"room_id": row[1],
|
||||||
|
@ -191,7 +326,16 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
"actions": json.loads(row[3]),
|
"actions": json.loads(row[3]),
|
||||||
"received_ts": row[4],
|
"received_ts": row[4],
|
||||||
} for row in after_read_receipt + no_read_receipt
|
} for row in after_read_receipt + no_read_receipt
|
||||||
])
|
]
|
||||||
|
|
||||||
|
# Now sort it so it's ordered correctly, since currently it will
|
||||||
|
# contain results from the first query, correctly ordered, followed
|
||||||
|
# by results from the second query, but we want them all ordered
|
||||||
|
# by received_ts (most recent first)
|
||||||
|
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
|
||||||
|
|
||||||
|
# Now return the first `limit`
|
||||||
|
defer.returnValue(notifs[:limit])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||||
|
|
|
@ -23,9 +23,11 @@ from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from collections import deque, namedtuple
|
from collections import deque, namedtuple, OrderedDict
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
@ -149,8 +151,29 @@ class _EventPeristenceQueue(object):
|
||||||
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
|
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_on_integrity_error(func):
|
||||||
|
"""Wraps a database function so that it gets retried on IntegrityError,
|
||||||
|
with `delete_existing=True` passed in.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: function that returns a Deferred and accepts a `delete_existing` arg
|
||||||
|
"""
|
||||||
|
@wraps(func)
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def f(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
res = yield func(self, *args, **kwargs)
|
||||||
|
except self.database_engine.module.IntegrityError:
|
||||||
|
logger.exception("IntegrityError, retrying.")
|
||||||
|
res = yield func(self, *args, delete_existing=True, **kwargs)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
class EventsStore(SQLBaseStore):
|
class EventsStore(SQLBaseStore):
|
||||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||||
|
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(EventsStore, self).__init__(hs)
|
super(EventsStore, self).__init__(hs)
|
||||||
|
@ -158,6 +181,10 @@ class EventsStore(SQLBaseStore):
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
||||||
)
|
)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
|
||||||
|
self._background_reindex_fields_sender,
|
||||||
|
)
|
||||||
|
|
||||||
self._event_persist_queue = _EventPeristenceQueue()
|
self._event_persist_queue = _EventPeristenceQueue()
|
||||||
|
|
||||||
|
@ -223,8 +250,10 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
self._event_persist_queue.handle_queue(room_id, persisting_queue)
|
self._event_persist_queue.handle_queue(room_id, persisting_queue)
|
||||||
|
|
||||||
|
@_retry_on_integrity_error
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _persist_events(self, events_and_contexts, backfilled=False):
|
def _persist_events(self, events_and_contexts, backfilled=False,
|
||||||
|
delete_existing=False):
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -267,12 +296,15 @@ class EventsStore(SQLBaseStore):
|
||||||
self._persist_events_txn,
|
self._persist_events_txn,
|
||||||
events_and_contexts=chunk,
|
events_and_contexts=chunk,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
delete_existing=delete_existing,
|
||||||
)
|
)
|
||||||
persist_event_counter.inc_by(len(chunk))
|
persist_event_counter.inc_by(len(chunk))
|
||||||
|
|
||||||
|
@_retry_on_integrity_error
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_event(self, event, context, current_state=None, backfilled=False):
|
def _persist_event(self, event, context, current_state=None, backfilled=False,
|
||||||
|
delete_existing=False):
|
||||||
try:
|
try:
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
with self._state_groups_id_gen.get_next() as state_group_id:
|
with self._state_groups_id_gen.get_next() as state_group_id:
|
||||||
|
@ -285,6 +317,7 @@ class EventsStore(SQLBaseStore):
|
||||||
context=context,
|
context=context,
|
||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
delete_existing=delete_existing,
|
||||||
)
|
)
|
||||||
persist_event_counter.inc()
|
persist_event_counter.inc()
|
||||||
except _RollbackButIsFineException:
|
except _RollbackButIsFineException:
|
||||||
|
@ -317,7 +350,7 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not events and not allow_none:
|
if not events and not allow_none:
|
||||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
defer.returnValue(events[0] if events else None)
|
defer.returnValue(events[0] if events else None)
|
||||||
|
|
||||||
|
@ -347,7 +380,8 @@ class EventsStore(SQLBaseStore):
|
||||||
defer.returnValue({e.event_id: e for e in events})
|
defer.returnValue({e.event_id: e for e in events})
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
|
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
|
||||||
|
delete_existing=False):
|
||||||
# We purposefully do this first since if we include a `current_state`
|
# We purposefully do this first since if we include a `current_state`
|
||||||
# key, we *want* to update the `current_state_events` table
|
# key, we *want* to update the `current_state_events` table
|
||||||
if current_state:
|
if current_state:
|
||||||
|
@ -355,7 +389,6 @@ class EventsStore(SQLBaseStore):
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
|
|
||||||
|
|
||||||
# Add an entry to the current_state_resets table to record the point
|
# Add an entry to the current_state_resets table to record the point
|
||||||
# where we clobbered the current state
|
# where we clobbered the current state
|
||||||
|
@ -388,10 +421,38 @@ class EventsStore(SQLBaseStore):
|
||||||
txn,
|
txn,
|
||||||
[(event, context)],
|
[(event, context)],
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
delete_existing=delete_existing,
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
|
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||||
|
delete_existing=False):
|
||||||
|
"""Insert some number of room events into the necessary database tables.
|
||||||
|
|
||||||
|
Rejected events are only inserted into the events table, the events_json table,
|
||||||
|
and the rejections table. Things reading from those table will need to check
|
||||||
|
whether the event was rejected.
|
||||||
|
|
||||||
|
If delete_existing is True then existing events will be purged from the
|
||||||
|
database before insertion. This is useful when retrying due to IntegrityError.
|
||||||
|
"""
|
||||||
|
# Ensure that we don't have the same event twice.
|
||||||
|
# Pick the earliest non-outlier if there is one, else the earliest one.
|
||||||
|
new_events_and_contexts = OrderedDict()
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
prev_event_context = new_events_and_contexts.get(event.event_id)
|
||||||
|
if prev_event_context:
|
||||||
|
if not event.internal_metadata.is_outlier():
|
||||||
|
if prev_event_context[0].internal_metadata.is_outlier():
|
||||||
|
# To ensure correct ordering we pop, as OrderedDict is
|
||||||
|
# ordered by first insertion.
|
||||||
|
new_events_and_contexts.pop(event.event_id, None)
|
||||||
|
new_events_and_contexts[event.event_id] = (event, context)
|
||||||
|
else:
|
||||||
|
new_events_and_contexts[event.event_id] = (event, context)
|
||||||
|
|
||||||
|
events_and_contexts = new_events_and_contexts.values()
|
||||||
|
|
||||||
depth_updates = {}
|
depth_updates = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Remove the any existing cache entries for the event_ids
|
# Remove the any existing cache entries for the event_ids
|
||||||
|
@ -402,21 +463,11 @@ class EventsStore(SQLBaseStore):
|
||||||
event.room_id, event.internal_metadata.stream_ordering,
|
event.room_id, event.internal_metadata.stream_ordering,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not event.internal_metadata.is_outlier():
|
if not event.internal_metadata.is_outlier() and not context.rejected:
|
||||||
depth_updates[event.room_id] = max(
|
depth_updates[event.room_id] = max(
|
||||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.push_actions:
|
|
||||||
self._set_push_actions_for_event_and_users_txn(
|
|
||||||
txn, event, context.push_actions
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
|
||||||
self._remove_push_actions_for_event_id_txn(
|
|
||||||
txn, event.room_id, event.redacts
|
|
||||||
)
|
|
||||||
|
|
||||||
for room_id, depth in depth_updates.items():
|
for room_id, depth in depth_updates.items():
|
||||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||||
|
|
||||||
|
@ -426,30 +477,21 @@ class EventsStore(SQLBaseStore):
|
||||||
),
|
),
|
||||||
[event.event_id for event, _ in events_and_contexts]
|
[event.event_id for event, _ in events_and_contexts]
|
||||||
)
|
)
|
||||||
|
|
||||||
have_persisted = {
|
have_persisted = {
|
||||||
event_id: outlier
|
event_id: outlier
|
||||||
for event_id, outlier in txn.fetchall()
|
for event_id, outlier in txn.fetchall()
|
||||||
}
|
}
|
||||||
|
|
||||||
event_map = {}
|
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Handle the case of the list including the same event multiple
|
if context.rejected:
|
||||||
# times. The tricky thing here is when they differ by whether
|
# If the event is rejected then we don't care if the event
|
||||||
# they are an outlier.
|
# was an outlier or not.
|
||||||
if event.event_id in event_map:
|
if event.event_id in have_persisted:
|
||||||
other = event_map[event.event_id]
|
# If we have already seen the event then ignore it.
|
||||||
|
|
||||||
if not other.internal_metadata.is_outlier():
|
|
||||||
to_remove.add(event)
|
to_remove.add(event)
|
||||||
continue
|
continue
|
||||||
elif not event.internal_metadata.is_outlier():
|
|
||||||
to_remove.add(event)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
to_remove.add(other)
|
|
||||||
|
|
||||||
event_map[event.event_id] = event
|
|
||||||
|
|
||||||
if event.event_id not in have_persisted:
|
if event.event_id not in have_persisted:
|
||||||
continue
|
continue
|
||||||
|
@ -458,6 +500,12 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
outlier_persisted = have_persisted[event.event_id]
|
outlier_persisted = have_persisted[event.event_id]
|
||||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||||
|
# We received a copy of an event that we had already stored as
|
||||||
|
# an outlier in the database. We now have some state at that
|
||||||
|
# so we need to update the state_groups table with that state.
|
||||||
|
|
||||||
|
# insert into the state_group, state_groups_state and
|
||||||
|
# event_to_state_groups tables.
|
||||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
self._store_mult_state_groups_txn(txn, ((event, context),))
|
||||||
|
|
||||||
metadata_json = encode_json(
|
metadata_json = encode_json(
|
||||||
|
@ -473,6 +521,8 @@ class EventsStore(SQLBaseStore):
|
||||||
(metadata_json, event.event_id,)
|
(metadata_json, event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add an entry to the ex_outlier_stream table to replicate the
|
||||||
|
# change in outlier status to our workers.
|
||||||
stream_order = event.internal_metadata.stream_ordering
|
stream_order = event.internal_metadata.stream_ordering
|
||||||
state_group_id = context.state_group or context.new_state_group_id
|
state_group_id = context.state_group or context.new_state_group_id
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
|
@ -494,6 +544,8 @@ class EventsStore(SQLBaseStore):
|
||||||
(False, event.event_id,)
|
(False, event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update the event_backward_extremities table now that this
|
||||||
|
# event isn't an outlier any more.
|
||||||
self._update_extremeties(txn, [event])
|
self._update_extremeties(txn, [event])
|
||||||
|
|
||||||
events_and_contexts = [
|
events_and_contexts = [
|
||||||
|
@ -501,38 +553,12 @@ class EventsStore(SQLBaseStore):
|
||||||
]
|
]
|
||||||
|
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
|
# Make sure we don't pass an empty list to functions that expect to
|
||||||
|
# be storing at least one element.
|
||||||
return
|
return
|
||||||
|
|
||||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
# From this point onwards the events are only events that we haven't
|
||||||
|
# seen before.
|
||||||
self._handle_mult_prev_events(
|
|
||||||
txn,
|
|
||||||
events=[event for event, _ in events_and_contexts],
|
|
||||||
)
|
|
||||||
|
|
||||||
for event, _ in events_and_contexts:
|
|
||||||
if event.type == EventTypes.Name:
|
|
||||||
self._store_room_name_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Topic:
|
|
||||||
self._store_room_topic_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Message:
|
|
||||||
self._store_room_message_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Redaction:
|
|
||||||
self._store_redaction(txn, event)
|
|
||||||
elif event.type == EventTypes.RoomHistoryVisibility:
|
|
||||||
self._store_history_visibility_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.GuestAccess:
|
|
||||||
self._store_guest_access_txn(txn, event)
|
|
||||||
|
|
||||||
self._store_room_members_txn(
|
|
||||||
txn,
|
|
||||||
[
|
|
||||||
event
|
|
||||||
for event, _ in events_and_contexts
|
|
||||||
if event.type == EventTypes.Member
|
|
||||||
],
|
|
||||||
backfilled=backfilled,
|
|
||||||
)
|
|
||||||
|
|
||||||
def event_dict(event):
|
def event_dict(event):
|
||||||
return {
|
return {
|
||||||
|
@ -544,6 +570,43 @@ class EventsStore(SQLBaseStore):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if delete_existing:
|
||||||
|
# For paranoia reasons, we go and delete all the existing entries
|
||||||
|
# for these events so we can reinsert them.
|
||||||
|
# This gets around any problems with some tables already having
|
||||||
|
# entries.
|
||||||
|
|
||||||
|
logger.info("Deleting existing")
|
||||||
|
|
||||||
|
for table in (
|
||||||
|
"events",
|
||||||
|
"event_auth",
|
||||||
|
"event_json",
|
||||||
|
"event_content_hashes",
|
||||||
|
"event_destinations",
|
||||||
|
"event_edge_hashes",
|
||||||
|
"event_edges",
|
||||||
|
"event_forward_extremities",
|
||||||
|
"event_push_actions",
|
||||||
|
"event_reference_hashes",
|
||||||
|
"event_search",
|
||||||
|
"event_signatures",
|
||||||
|
"event_to_state_groups",
|
||||||
|
"guest_access",
|
||||||
|
"history_visibility",
|
||||||
|
"local_invites",
|
||||||
|
"room_names",
|
||||||
|
"state_events",
|
||||||
|
"rejections",
|
||||||
|
"redactions",
|
||||||
|
"room_memberships",
|
||||||
|
"state_events"
|
||||||
|
):
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||||
|
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||||
|
)
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_json",
|
table="event_json",
|
||||||
|
@ -576,15 +639,51 @@ class EventsStore(SQLBaseStore):
|
||||||
"content": encode_json(event.content).decode("UTF-8"),
|
"content": encode_json(event.content).decode("UTF-8"),
|
||||||
"origin_server_ts": int(event.origin_server_ts),
|
"origin_server_ts": int(event.origin_server_ts),
|
||||||
"received_ts": self._clock.time_msec(),
|
"received_ts": self._clock.time_msec(),
|
||||||
|
"sender": event.sender,
|
||||||
|
"contains_url": (
|
||||||
|
"url" in event.content
|
||||||
|
and isinstance(event.content["url"], basestring)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove the rejected events from the list now that we've added them
|
||||||
|
# to the events table and the events_json table.
|
||||||
|
to_remove = set()
|
||||||
|
for event, context in events_and_contexts:
|
||||||
if context.rejected:
|
if context.rejected:
|
||||||
|
# Insert the event_id into the rejections table
|
||||||
self._store_rejections_txn(
|
self._store_rejections_txn(
|
||||||
txn, event.event_id, context.rejected
|
txn, event.event_id, context.rejected
|
||||||
)
|
)
|
||||||
|
to_remove.add(event)
|
||||||
|
|
||||||
|
events_and_contexts = [
|
||||||
|
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
if not events_and_contexts:
|
||||||
|
# Make sure we don't pass an empty list to functions that expect to
|
||||||
|
# be storing at least one element.
|
||||||
|
return
|
||||||
|
|
||||||
|
# From this point onwards the events are only ones that weren't rejected.
|
||||||
|
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
# Insert all the push actions into the event_push_actions table.
|
||||||
|
if context.push_actions:
|
||||||
|
self._set_push_actions_for_event_and_users_txn(
|
||||||
|
txn, event, context.push_actions
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||||
|
# Remove the entries in the event_push_actions table for the
|
||||||
|
# redacted event.
|
||||||
|
self._remove_push_actions_for_event_id_txn(
|
||||||
|
txn, event.room_id, event.redacts
|
||||||
|
)
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -600,6 +699,49 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Insert into the state_groups, state_groups_state, and
|
||||||
|
# event_to_state_groups tables.
|
||||||
|
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||||
|
|
||||||
|
# Update the event_forward_extremities, event_backward_extremities and
|
||||||
|
# event_edges tables.
|
||||||
|
self._handle_mult_prev_events(
|
||||||
|
txn,
|
||||||
|
events=[event for event, _ in events_and_contexts],
|
||||||
|
)
|
||||||
|
|
||||||
|
for event, _ in events_and_contexts:
|
||||||
|
if event.type == EventTypes.Name:
|
||||||
|
# Insert into the room_names and event_search tables.
|
||||||
|
self._store_room_name_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Topic:
|
||||||
|
# Insert into the topics table and event_search table.
|
||||||
|
self._store_room_topic_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Message:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_room_message_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Redaction:
|
||||||
|
# Insert into the redactions table.
|
||||||
|
self._store_redaction(txn, event)
|
||||||
|
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_history_visibility_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.GuestAccess:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_guest_access_txn(txn, event)
|
||||||
|
|
||||||
|
# Insert into the room_memberships table.
|
||||||
|
self._store_room_members_txn(
|
||||||
|
txn,
|
||||||
|
[
|
||||||
|
event
|
||||||
|
for event, _ in events_and_contexts
|
||||||
|
if event.type == EventTypes.Member
|
||||||
|
],
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert event_reference_hashes table.
|
||||||
self._store_event_reference_hashes_txn(
|
self._store_event_reference_hashes_txn(
|
||||||
txn, [event for event, _ in events_and_contexts]
|
txn, [event for event, _ in events_and_contexts]
|
||||||
)
|
)
|
||||||
|
@ -644,6 +786,7 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefill the event cache
|
||||||
self._add_to_cache(txn, events_and_contexts)
|
self._add_to_cache(txn, events_and_contexts)
|
||||||
|
|
||||||
if backfilled:
|
if backfilled:
|
||||||
|
@ -656,22 +799,11 @@ class EventsStore(SQLBaseStore):
|
||||||
# Outlier events shouldn't clobber the current state.
|
# Outlier events shouldn't clobber the current state.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.rejected:
|
|
||||||
# If the event failed it's auth checks then it shouldn't
|
|
||||||
# clobbler the current state.
|
|
||||||
continue
|
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._get_current_state_for_key.invalidate,
|
self._get_current_state_for_key.invalidate,
|
||||||
(event.room_id, event.type, event.state_key,)
|
(event.room_id, event.type, event.state_key,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
|
||||||
txn.call_after(
|
|
||||||
self.get_room_name_and_aliases.invalidate,
|
|
||||||
(event.room_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_upsert_txn(
|
self._simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
"current_state_events",
|
"current_state_events",
|
||||||
|
@ -1121,6 +1253,78 @@ class EventsStore(SQLBaseStore):
|
||||||
ret = yield self.runInteraction("count_messages", _count_messages)
|
ret = yield self.runInteraction("count_messages", _count_messages)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _background_reindex_fields_sender(self, progress, batch_size):
|
||||||
|
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||||
|
max_stream_id = progress["max_stream_id_exclusive"]
|
||||||
|
rows_inserted = progress.get("rows_inserted", 0)
|
||||||
|
|
||||||
|
INSERT_CLUMP_SIZE = 1000
|
||||||
|
|
||||||
|
def reindex_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_ordering, event_id, json FROM events"
|
||||||
|
" INNER JOIN event_json USING (event_id)"
|
||||||
|
" WHERE ? <= stream_ordering AND stream_ordering < ?"
|
||||||
|
" ORDER BY stream_ordering DESC"
|
||||||
|
" LIMIT ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||||
|
|
||||||
|
rows = txn.fetchall()
|
||||||
|
if not rows:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
min_stream_id = rows[-1][0]
|
||||||
|
|
||||||
|
update_rows = []
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
event_id = row[1]
|
||||||
|
event_json = json.loads(row[2])
|
||||||
|
sender = event_json["sender"]
|
||||||
|
content = event_json["content"]
|
||||||
|
|
||||||
|
contains_url = "url" in content
|
||||||
|
if contains_url:
|
||||||
|
contains_url &= isinstance(content["url"], basestring)
|
||||||
|
except (KeyError, AttributeError):
|
||||||
|
# If the event is missing a necessary field then
|
||||||
|
# skip over it.
|
||||||
|
continue
|
||||||
|
|
||||||
|
update_rows.append((sender, contains_url, event_id))
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
|
||||||
|
clump = update_rows[index:index + INSERT_CLUMP_SIZE]
|
||||||
|
txn.executemany(sql, clump)
|
||||||
|
|
||||||
|
progress = {
|
||||||
|
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||||
|
"max_stream_id_exclusive": min_stream_id,
|
||||||
|
"rows_inserted": rows_inserted + len(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
self._background_update_progress_txn(
|
||||||
|
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
result = yield self.runInteraction(
|
||||||
|
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_reindex_origin_server_ts(self, progress, batch_size):
|
def _background_reindex_origin_server_ts(self, progress, batch_size):
|
||||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||||
|
@ -1288,6 +1492,162 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||||
|
|
||||||
|
def delete_old_state(self, room_id, topological_ordering):
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_old_state",
|
||||||
|
self._delete_old_state_txn, room_id, topological_ordering
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
|
||||||
|
"""Deletes old room state
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tables that should be pruned:
|
||||||
|
# event_auth
|
||||||
|
# event_backward_extremities
|
||||||
|
# event_content_hashes
|
||||||
|
# event_destinations
|
||||||
|
# event_edge_hashes
|
||||||
|
# event_edges
|
||||||
|
# event_forward_extremities
|
||||||
|
# event_json
|
||||||
|
# event_push_actions
|
||||||
|
# event_reference_hashes
|
||||||
|
# event_search
|
||||||
|
# event_signatures
|
||||||
|
# event_to_state_groups
|
||||||
|
# events
|
||||||
|
# rejections
|
||||||
|
# room_depth
|
||||||
|
# state_groups
|
||||||
|
# state_groups_state
|
||||||
|
|
||||||
|
# First ensure that we're not about to delete all the forward extremeties
|
||||||
|
txn.execute(
|
||||||
|
"SELECT e.event_id, e.depth FROM events as e "
|
||||||
|
"INNER JOIN event_forward_extremities as f "
|
||||||
|
"ON e.event_id = f.event_id "
|
||||||
|
"AND e.room_id = f.room_id "
|
||||||
|
"WHERE f.room_id = ?",
|
||||||
|
(room_id,)
|
||||||
|
)
|
||||||
|
rows = txn.fetchall()
|
||||||
|
max_depth = max(row[0] for row in rows)
|
||||||
|
|
||||||
|
if max_depth <= topological_ordering:
|
||||||
|
# We need to ensure we don't delete all the events from the datanase
|
||||||
|
# otherwise we wouldn't be able to send any events (due to not
|
||||||
|
# having any backwards extremeties)
|
||||||
|
raise SynapseError(
|
||||||
|
400, "topological_ordering is greater than forward extremeties"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"SELECT event_id, state_key FROM events"
|
||||||
|
" LEFT JOIN state_events USING (room_id, event_id)"
|
||||||
|
" WHERE room_id = ? AND topological_ordering < ?",
|
||||||
|
(room_id, topological_ordering,)
|
||||||
|
)
|
||||||
|
event_rows = txn.fetchall()
|
||||||
|
|
||||||
|
# We calculate the new entries for the backward extremeties by finding
|
||||||
|
# all events that point to events that are to be purged
|
||||||
|
txn.execute(
|
||||||
|
"SELECT DISTINCT e.event_id FROM events as e"
|
||||||
|
" INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
|
||||||
|
" INNER JOIN events as e2 ON e2.event_id = ed.event_id"
|
||||||
|
" WHERE e.room_id = ? AND e.topological_ordering < ?"
|
||||||
|
" AND e2.topological_ordering >= ?",
|
||||||
|
(room_id, topological_ordering, topological_ordering)
|
||||||
|
)
|
||||||
|
new_backwards_extrems = txn.fetchall()
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"DELETE FROM event_backward_extremities WHERE room_id = ?",
|
||||||
|
(room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update backward extremeties
|
||||||
|
txn.executemany(
|
||||||
|
"INSERT INTO event_backward_extremities (room_id, event_id)"
|
||||||
|
" VALUES (?, ?)",
|
||||||
|
[
|
||||||
|
(room_id, event_id) for event_id, in new_backwards_extrems
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all state groups that are only referenced by events that are
|
||||||
|
# to be deleted.
|
||||||
|
txn.execute(
|
||||||
|
"SELECT state_group FROM event_to_state_groups"
|
||||||
|
" INNER JOIN events USING (event_id)"
|
||||||
|
" WHERE state_group IN ("
|
||||||
|
" SELECT DISTINCT state_group FROM events"
|
||||||
|
" INNER JOIN event_to_state_groups USING (event_id)"
|
||||||
|
" WHERE room_id = ? AND topological_ordering < ?"
|
||||||
|
" )"
|
||||||
|
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
|
||||||
|
(room_id, topological_ordering, topological_ordering)
|
||||||
|
)
|
||||||
|
state_rows = txn.fetchall()
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM state_groups_state WHERE state_group = ?",
|
||||||
|
state_rows
|
||||||
|
)
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM state_groups WHERE id = ?",
|
||||||
|
state_rows
|
||||||
|
)
|
||||||
|
# Delete all non-state
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM event_to_state_groups WHERE event_id = ?",
|
||||||
|
[(event_id,) for event_id, _ in event_rows]
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||||
|
(topological_ordering, room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete all remote non-state events
|
||||||
|
to_delete = [
|
||||||
|
(event_id,) for event_id, state_key in event_rows
|
||||||
|
if state_key is None and not self.hs.is_mine_id(event_id)
|
||||||
|
]
|
||||||
|
for table in (
|
||||||
|
"events",
|
||||||
|
"event_json",
|
||||||
|
"event_auth",
|
||||||
|
"event_content_hashes",
|
||||||
|
"event_destinations",
|
||||||
|
"event_edge_hashes",
|
||||||
|
"event_edges",
|
||||||
|
"event_forward_extremities",
|
||||||
|
"event_push_actions",
|
||||||
|
"event_reference_hashes",
|
||||||
|
"event_search",
|
||||||
|
"event_signatures",
|
||||||
|
"rejections",
|
||||||
|
):
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||||
|
to_delete
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM events WHERE event_id = ?",
|
||||||
|
to_delete
|
||||||
|
)
|
||||||
|
# Mark all state and own events as outliers
|
||||||
|
txn.executemany(
|
||||||
|
"UPDATE events SET outlier = ?"
|
||||||
|
" WHERE event_id = ?",
|
||||||
|
[
|
||||||
|
(True, event_id,) for event_id, state_key in event_rows
|
||||||
|
if state_key is not None or self.hs.is_mine_id(event_id)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AllNewEventsResult = namedtuple("AllNewEventsResult", [
|
AllNewEventsResult = namedtuple("AllNewEventsResult", [
|
||||||
"new_forward_events", "new_backfill_events",
|
"new_forward_events", "new_backfill_events",
|
||||||
|
|
|
@ -22,6 +22,10 @@ import OpenSSL
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KeyStore(SQLBaseStore):
|
class KeyStore(SQLBaseStore):
|
||||||
"""Persistence for signature verification keys and tls X.509 certificates
|
"""Persistence for signature verification keys and tls X.509 certificates
|
||||||
|
@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def get_all_server_verify_keys(self, server_name):
|
def _get_server_verify_key(self, server_name, key_id):
|
||||||
rows = yield self._simple_select_list(
|
verify_key_bytes = yield self._simple_select_one_onecol(
|
||||||
table="server_signature_keys",
|
table="server_signature_keys",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"server_name": server_name,
|
"server_name": server_name,
|
||||||
|
"key_id": key_id,
|
||||||
},
|
},
|
||||||
retcols=["key_id", "verify_key"],
|
retcol="verify_key",
|
||||||
desc="get_all_server_verify_keys",
|
desc="_get_server_verify_key",
|
||||||
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
if verify_key_bytes:
|
||||||
row["key_id"]: decode_verify_key_bytes(
|
defer.returnValue(decode_verify_key_bytes(
|
||||||
row["key_id"], str(row["verify_key"])
|
key_id, str(verify_key_bytes)
|
||||||
)
|
))
|
||||||
for row in rows
|
|
||||||
})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_keys(self, server_name, key_ids):
|
def get_server_verify_keys(self, server_name, key_ids):
|
||||||
|
@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
(list of VerifyKey): The verification keys.
|
(list of VerifyKey): The verification keys.
|
||||||
"""
|
"""
|
||||||
keys = yield self.get_all_server_verify_keys(server_name)
|
keys = {}
|
||||||
defer.returnValue({
|
for key_id in key_ids:
|
||||||
k: keys[k]
|
key = yield self._get_server_verify_key(server_name, key_id)
|
||||||
for k in key_ids
|
if key:
|
||||||
if k in keys and keys[k]
|
keys[key_id] = key
|
||||||
})
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||||
|
@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):
|
||||||
desc="store_server_verify_key",
|
desc="store_server_verify_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_all_server_verify_keys.invalidate((server_name,))
|
|
||||||
|
|
||||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||||
"""Stores the JSON bytes for a set of keys from a server
|
"""Stores the JSON bytes for a set of keys from a server
|
||||||
|
|
|
@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"created_ts": time_now_ms,
|
"created_ts": time_now_ms,
|
||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"filesystem_id": filesystem_id,
|
"filesystem_id": filesystem_id,
|
||||||
|
"last_access_ts": time_now_ms,
|
||||||
},
|
},
|
||||||
desc="store_cached_remote_media",
|
desc="store_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
|
||||||
|
def update_cache_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||||
|
" WHERE media_origin = ? AND media_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(sql, (
|
||||||
|
(time_ts, media_origin, media_id)
|
||||||
|
for media_origin, media_id in origin_id_tuples
|
||||||
|
))
|
||||||
|
|
||||||
|
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
||||||
|
|
||||||
def get_remote_media_thumbnails(self, origin, media_id):
|
def get_remote_media_thumbnails(self, origin, media_id):
|
||||||
return self._simple_select_list(
|
return self._simple_select_list(
|
||||||
"remote_media_cache_thumbnails",
|
"remote_media_cache_thumbnails",
|
||||||
|
@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
desc="store_remote_media_thumbnail",
|
desc="store_remote_media_thumbnail",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_remote_media_before(self, before_ts):
|
||||||
|
sql = (
|
||||||
|
"SELECT media_origin, media_id, filesystem_id"
|
||||||
|
" FROM remote_media_cache"
|
||||||
|
" WHERE last_access_ts < ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._execute(
|
||||||
|
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_remote_media(self, media_origin, media_id):
|
||||||
|
def delete_remote_media_txn(txn):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
"remote_media_cache",
|
||||||
|
keyvalues={
|
||||||
|
"media_origin": media_origin, "media_id": media_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
"remote_media_cache_thumbnails",
|
||||||
|
keyvalues={
|
||||||
|
"media_origin": media_origin, "media_id": media_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 32
|
SCHEMA_VERSION = 33
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
|
@ -18,25 +18,40 @@ import re
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
from synapse.storage import background_updates
|
||||||
from ._base import SQLBaseStore
|
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RegistrationStore, self).__init__(hs)
|
super(RegistrationStore, self).__init__(hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"access_tokens_device_index",
|
||||||
|
index_name="access_tokens_device_id",
|
||||||
|
table="access_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"refresh_tokens_device_index",
|
||||||
|
index_name="refresh_tokens_device_id",
|
||||||
|
table="refresh_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_access_token_to_user(self, user_id, token):
|
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
user_id (str): The user ID.
|
||||||
token (str): The new access token to add.
|
token (str): The new access token to add.
|
||||||
|
device_id (str): ID of the device to associate with the access
|
||||||
|
token
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
"""
|
"""
|
||||||
|
@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"id": next_id,
|
"id": next_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"token": token
|
"token": token,
|
||||||
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_refresh_token_to_user(self, user_id, token):
|
def add_refresh_token_to_user(self, user_id, token, device_id=None):
|
||||||
"""Adds a refresh token for the given user.
|
"""Adds a refresh token for the given user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
user_id (str): The user ID.
|
||||||
token (str): The new refresh token to add.
|
token (str): The new refresh token to add.
|
||||||
|
device_id (str): ID of the device to associate with the access
|
||||||
|
token
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
"""
|
"""
|
||||||
|
@ -69,20 +87,23 @@ class RegistrationStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"id": next_id,
|
"id": next_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"token": token
|
"token": token,
|
||||||
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
desc="add_refresh_token_to_user",
|
desc="add_refresh_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, user_id, token, password_hash,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None):
|
create_profile_with_localpart=None, admin=False):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The desired user ID to register.
|
user_id (str): The desired user ID to register.
|
||||||
token (str): The desired access token to use for this user.
|
token (str): The desired access token to use for this user. If this
|
||||||
|
is not None, the given access token is associated with the user
|
||||||
|
id.
|
||||||
password_hash (str): Optional. The password hash for this user.
|
password_hash (str): Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest (bool): Optional. Whether this is a guest account being
|
||||||
upgraded to a non-guest account.
|
upgraded to a non-guest account.
|
||||||
|
@ -104,6 +125,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
|
admin
|
||||||
)
|
)
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
self.get_user_by_id.invalidate((user_id,))
|
||||||
self.is_guest.invalidate((user_id,))
|
self.is_guest.invalidate((user_id,))
|
||||||
|
@ -118,6 +140,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
|
admin,
|
||||||
):
|
):
|
||||||
now = int(self.clock.time())
|
now = int(self.clock.time())
|
||||||
|
|
||||||
|
@ -125,29 +148,48 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if was_guest:
|
if was_guest:
|
||||||
txn.execute("UPDATE users SET"
|
# Ensure that the guest user actually exists
|
||||||
" password_hash = ?,"
|
# ``allow_none=False`` makes this raise an exception
|
||||||
" upgrade_ts = ?,"
|
# if the row isn't in the database.
|
||||||
" is_guest = ?"
|
self._simple_select_one_txn(
|
||||||
" WHERE name = ?",
|
txn,
|
||||||
[password_hash, now, 1 if make_guest else 0, user_id])
|
"users",
|
||||||
|
keyvalues={
|
||||||
|
"name": user_id,
|
||||||
|
"is_guest": 1,
|
||||||
|
},
|
||||||
|
retcols=("name",),
|
||||||
|
allow_none=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"users",
|
||||||
|
keyvalues={
|
||||||
|
"name": user_id,
|
||||||
|
"is_guest": 1,
|
||||||
|
},
|
||||||
|
updatevalues={
|
||||||
|
"password_hash": password_hash,
|
||||||
|
"upgrade_ts": now,
|
||||||
|
"is_guest": 1 if make_guest else 0,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"admin": 1 if admin else 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
txn.execute("INSERT INTO users "
|
self._simple_insert_txn(
|
||||||
"("
|
txn,
|
||||||
" name,"
|
"users",
|
||||||
" password_hash,"
|
values={
|
||||||
" creation_ts,"
|
"name": user_id,
|
||||||
" is_guest,"
|
"password_hash": password_hash,
|
||||||
" appservice_id"
|
"creation_ts": now,
|
||||||
") "
|
"is_guest": 1 if make_guest else 0,
|
||||||
"VALUES (?,?,?,?,?)",
|
"appservice_id": appservice_id,
|
||||||
[
|
"admin": 1 if admin else 0,
|
||||||
user_id,
|
}
|
||||||
password_hash,
|
)
|
||||||
now,
|
|
||||||
1 if make_guest else 0,
|
|
||||||
appservice_id,
|
|
||||||
])
|
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
raise StoreError(
|
raise StoreError(
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
|
@ -209,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
self.get_user_by_id.invalidate((user_id,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
|
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||||
def f(txn):
|
device_id=None,
|
||||||
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
delete_refresh_tokens=False):
|
||||||
|
"""
|
||||||
|
Invalidate access/refresh tokens belonging to a user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of user the tokens belong to
|
||||||
|
except_token_ids (list[str]): list of access_tokens which should
|
||||||
|
*not* be deleted
|
||||||
|
device_id (str|None): ID of device the tokens are associated with.
|
||||||
|
If None, tokens associated with any device (or no device) will
|
||||||
|
be deleted
|
||||||
|
delete_refresh_tokens (bool): True to delete refresh tokens as
|
||||||
|
well as access tokens.
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
def f(txn, table, except_tokens, call_after_delete):
|
||||||
|
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
||||||
clauses = [user_id]
|
clauses = [user_id]
|
||||||
|
|
||||||
if except_token_ids:
|
if device_id is not None:
|
||||||
|
sql += " AND device_id = ?"
|
||||||
|
clauses.append(device_id)
|
||||||
|
|
||||||
|
if except_tokens:
|
||||||
sql += " AND id NOT IN (%s)" % (
|
sql += " AND id NOT IN (%s)" % (
|
||||||
",".join(["?" for _ in except_token_ids]),
|
",".join(["?" for _ in except_tokens]),
|
||||||
)
|
)
|
||||||
clauses += except_token_ids
|
clauses += except_tokens
|
||||||
|
|
||||||
txn.execute(sql, clauses)
|
txn.execute(sql, clauses)
|
||||||
|
|
||||||
|
@ -227,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
|
||||||
n = 100
|
n = 100
|
||||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
|
if call_after_delete:
|
||||||
for row in chunk:
|
for row in chunk:
|
||||||
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
|
txn.call_after(call_after_delete, (row[0],))
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"DELETE FROM access_tokens WHERE token in (%s)" % (
|
"DELETE FROM %s WHERE token in (%s)" % (
|
||||||
|
table,
|
||||||
",".join(["?" for _ in chunk]),
|
",".join(["?" for _ in chunk]),
|
||||||
), [r[0] for r in chunk]
|
), [r[0] for r in chunk]
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.runInteraction("user_delete_access_tokens", f)
|
# delete refresh tokens first, to stop new access tokens being
|
||||||
|
# allocated while our backs are turned
|
||||||
|
if delete_refresh_tokens:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"user_delete_access_tokens", f,
|
||||||
|
table="refresh_tokens",
|
||||||
|
except_tokens=[],
|
||||||
|
call_after_delete=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.runInteraction(
|
||||||
|
"user_delete_access_tokens", f,
|
||||||
|
table="access_tokens",
|
||||||
|
except_tokens=except_token_ids,
|
||||||
|
call_after_delete=self.get_user_by_access_token.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -259,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
Args:
|
Args:
|
||||||
token (str): The access token of a user.
|
token (str): The access token of a user.
|
||||||
Returns:
|
Returns:
|
||||||
dict: Including the name (user_id) and the ID of their access token.
|
defer.Deferred: None, if the token did not match, otherwise dict
|
||||||
Raises:
|
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||||
StoreError if no user was found.
|
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_user_by_access_token",
|
"get_user_by_access_token",
|
||||||
|
@ -270,18 +349,18 @@ class RegistrationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def exchange_refresh_token(self, refresh_token, token_generator):
|
def exchange_refresh_token(self, refresh_token, token_generator):
|
||||||
"""Exchange a refresh token for a new access token and refresh token.
|
"""Exchange a refresh token for a new one.
|
||||||
|
|
||||||
Doing so invalidates the old refresh token - refresh tokens are single
|
Doing so invalidates the old refresh token - refresh tokens are single
|
||||||
use.
|
use.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): The refresh token of a user.
|
refresh_token (str): The refresh token of a user.
|
||||||
token_generator (fn: str -> str): Function which, when given a
|
token_generator (fn: str -> str): Function which, when given a
|
||||||
user ID, returns a unique refresh token for that user. This
|
user ID, returns a unique refresh token for that user. This
|
||||||
function must never return the same value twice.
|
function must never return the same value twice.
|
||||||
Returns:
|
Returns:
|
||||||
tuple of (user_id, refresh_token)
|
tuple of (user_id, new_refresh_token, device_id)
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if no user was found with that refresh token.
|
StoreError if no user was found with that refresh token.
|
||||||
"""
|
"""
|
||||||
|
@ -293,12 +372,13 @@ class RegistrationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _exchange_refresh_token(self, txn, old_token, token_generator):
|
def _exchange_refresh_token(self, txn, old_token, token_generator):
|
||||||
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
|
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
|
||||||
txn.execute(sql, (old_token,))
|
txn.execute(sql, (old_token,))
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
if not rows:
|
if not rows:
|
||||||
raise StoreError(403, "Did not recognize refresh token")
|
raise StoreError(403, "Did not recognize refresh token")
|
||||||
user_id = rows[0]["user_id"]
|
user_id = rows[0]["user_id"]
|
||||||
|
device_id = rows[0]["device_id"]
|
||||||
|
|
||||||
# TODO(danielwh): Maybe perform a validation on the macaroon that
|
# TODO(danielwh): Maybe perform a validation on the macaroon that
|
||||||
# macaroon.user_id == user_id.
|
# macaroon.user_id == user_id.
|
||||||
|
@ -307,7 +387,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
|
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
|
||||||
txn.execute(sql, (new_token, old_token,))
|
txn.execute(sql, (new_token, old_token,))
|
||||||
|
|
||||||
return user_id, new_token
|
return user_id, new_token, device_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def is_server_admin(self, user):
|
def is_server_admin(self, user):
|
||||||
|
@ -335,7 +415,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
|
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||||
|
" access_tokens.device_id"
|
||||||
" FROM users"
|
" FROM users"
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||||
" WHERE token = ?"
|
" WHERE token = ?"
|
||||||
|
@ -384,6 +465,15 @@ class RegistrationStore(SQLBaseStore):
|
||||||
defer.returnValue(ret['user_id'])
|
defer.returnValue(ret['user_id'])
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
def user_delete_threepids(self, user_id):
|
||||||
|
return self._simple_delete(
|
||||||
|
"user_threepids",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
desc="user_delete_threepids",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def count_all_users(self):
|
def count_all_users(self):
|
||||||
"""Counts all users registered on the homeserver."""
|
"""Counts all users registered on the homeserver."""
|
||||||
|
|
|
@ -18,7 +18,6 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
|
||||||
from .engines import PostgresEngine, Sqlite3Engine
|
from .engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
@ -192,49 +191,6 @@ class RoomStore(SQLBaseStore):
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
|
||||||
def get_room_name_and_aliases(self, room_id):
|
|
||||||
def get_room_name(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT name FROM room_names"
|
|
||||||
" INNER JOIN current_state_events USING (room_id, event_id)"
|
|
||||||
" WHERE room_id = ?"
|
|
||||||
" LIMIT 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (room_id,))
|
|
||||||
rows = txn.fetchall()
|
|
||||||
if rows:
|
|
||||||
return rows[0][0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return [row[0] for row in txn.fetchall()]
|
|
||||||
|
|
||||||
def get_room_aliases(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT content FROM current_state_events"
|
|
||||||
" INNER JOIN events USING (room_id, event_id)"
|
|
||||||
" WHERE room_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (room_id,))
|
|
||||||
return [row[0] for row in txn.fetchall()]
|
|
||||||
|
|
||||||
name = yield self.runInteraction("get_room_name", get_room_name)
|
|
||||||
alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
|
|
||||||
|
|
||||||
aliases = []
|
|
||||||
|
|
||||||
for c in alias_contents:
|
|
||||||
try:
|
|
||||||
content = json.loads(c)
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
aliases.extend(content.get('aliases', []))
|
|
||||||
|
|
||||||
defer.returnValue((name, aliases))
|
|
||||||
|
|
||||||
def add_event_report(self, room_id, event_id, user_id, reason, content,
|
def add_event_report(self, room_id, event_id, user_id, reason, content,
|
||||||
received_ts):
|
received_ts):
|
||||||
next_id = self._event_reports_id_gen.get_next()
|
next_id = self._event_reports_id_gen.get_next()
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('access_tokens_device_index', '{}');
|
21
synapse/storage/schema/delta/33/devices.sql
Normal file
21
synapse/storage/schema/delta/33/devices.sql
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE devices (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
|
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
|
||||||
|
);
|
19
synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
Normal file
19
synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- make sure that we have a device record for each set of E2E keys, so that the
|
||||||
|
-- user can delete them if they like.
|
||||||
|
INSERT INTO devices
|
||||||
|
SELECT user_id, device_id, NULL FROM e2e_device_keys_json;
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue