Merge branch 'release-v0.16.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-06-09 14:21:23 +01:00
commit ba0406d10d
218 changed files with 13400 additions and 5458 deletions

View file

@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com>
Niklas Riekenbrauck <nikriek at gmail dot.com> Niklas Riekenbrauck <nikriek at gmail dot.com>
* Add JWT support for registration and login * Add JWT support for registration and login
Christoph Witzany <christoph at web.crofting.com>
* Add LDAP support for authentication

View file

@ -1,3 +1,126 @@
Changes in synapse v0.16.0 (2016-06-09)
=======================================
NB: As of v0.14 all AS config files must have an ID field.
Bug fixes:
* Don't make rooms published by default (PR #857)
Changes in synapse v0.16.0-rc2 (2016-06-08)
===========================================
Features:
* Add configuration option for tuning GC via ``gc.set_threshold`` (PR #849)
Changes:
* Record metrics about GC (PR #771, #847, #852)
* Add metric counter for number of persisted events (PR #841)
Bug fixes:
* Fix 'From' header in email notifications (PR #843)
* Fix presence where timeouts were not being fired for the first 8h after
restarts (PR #842)
* Fix bug where synapse sent malformed transactions to AS's when retrying
transactions (Commits 310197b, 8437906)
Performance Improvements:
* Remove event fetching from DB threads (PR #835)
* Change the way we cache events (PR #836)
* Add events to cache when we persist them (PR #840)
Changes in synapse v0.16.0-rc1 (2016-06-03)
===========================================
Version 0.15 was not released. See v0.15.0-rc1 below for additional changes.
Features:
* Add email notifications for missed messages (PR #759, #786, #799, #810, #815,
#821)
* Add a ``url_preview_ip_range_whitelist`` config param (PR #760)
* Add /report endpoint (PR #762)
* Add basic ignore user API (PR #763)
* Add an openidish mechanism for proving that you own a given user_id (PR #765)
* Allow clients to specify a server_name to avoid 'No known servers' (PR #794)
* Add secondary_directory_servers option to fetch room list from other servers
(PR #808, #813)
Changes:
* Report per request metrics for all of the things using request_handler (PR
#756)
* Correctly handle ``NULL`` password hashes from the database (PR #775)
* Allow receipts for events we haven't seen in the db (PR #784)
* Make synctl read a cache factor from config file (PR #785)
* Increment badge count per missed convo, not per msg (PR #793)
* Special case m.room.third_party_invite event auth to match invites (PR #814)
Bug fixes:
* Fix typo in event_auth servlet path (PR #757)
* Fix password reset (PR #758)
Performance improvements:
* Reduce database inserts when sending transactions (PR #767)
* Queue events by room for persistence (PR #768)
* Add cache to ``get_user_by_id`` (PR #772)
* Add and use ``get_domain_from_id`` (PR #773)
* Use tree cache for ``get_linearized_receipts_for_room`` (PR #779)
* Remove unused indices (PR #782)
* Add caches to ``bulk_get_push_rules*`` (PR #804)
* Cache ``get_event_reference_hashes`` (PR #806)
* Add ``get_users_with_read_receipts_in_room`` cache (PR #809)
* Use state to calculate ``get_users_in_room`` (PR #811)
* Load push rules in storage layer so that they get cached (PR #825)
* Make ``get_joined_hosts_for_room`` use get_users_in_room (PR #828)
* Poke notifier on next reactor tick (PR #829)
* Change CacheMetrics to be quicker (PR #830)
Changes in synapse v0.15.0-rc1 (2016-04-26)
===========================================
Features:
* Add login support for Javascript Web Tokens, thanks to Niklas Riekenbrauck
(PR #671,#687)
* Add URL previewing support (PR #688)
* Add login support for LDAP, thanks to Christoph Witzany (PR #701)
* Add GET endpoint for pushers (PR #716)
Changes:
* Never notify for member events (PR #667)
* Deduplicate identical ``/sync`` requests (PR #668)
* Require user to have left room to forget room (PR #673)
* Use DNS cache if within TTL (PR #677)
* Let users see their own leave events (PR #699)
* Deduplicate membership changes (PR #700)
* Increase performance of pusher code (PR #705)
* Respond with error status 504 if failed to talk to remote server (PR #731)
* Increase search performance on postgres (PR #745)
Bug fixes:
* Fix bug where disabling all notifications still resulted in push (PR #678)
* Fix bug where users couldn't reject remote invites if remote refused (PR #691)
* Fix bug where synapse attempted to backfill from itself (PR #693)
* Fix bug where profile information was not correctly added when joining remote
rooms (PR #703)
* Fix bug where register API required incorrect key name for AS registration
(PR #727)
Changes in synapse v0.14.0 (2016-03-30) Changes in synapse v0.14.0 (2016-03-30)
======================================= =======================================
@ -511,7 +634,7 @@ Configuration:
* Add support for changing the bind host of the metrics listener via the * Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option. ``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21) Changes in synapse v0.9.0-r5 (2015-05-21)
========================================= =========================================
@ -853,7 +976,7 @@ See UPGRADE for information about changes to the client server API, including
breaking backwards compatibility with VoIP calls and registration API. breaking backwards compatibility with VoIP calls and registration API.
Homeserver: Homeserver:
* When a user changes their displayname or avatar the server will now update * When a user changes their displayname or avatar the server will now update
all their join states to reflect this. all their join states to reflect this.
* The server now adds "age" key to events to indicate how old they are. This * The server now adds "age" key to events to indicate how old they are. This
is clock independent, so at no point does any server or webclient have to is clock independent, so at no point does any server or webclient have to
@ -911,7 +1034,7 @@ Changes in synapse 0.2.2 (2014-09-06)
===================================== =====================================
Homeserver: Homeserver:
* When the server returns state events it now also includes the previous * When the server returns state events it now also includes the previous
content. content.
* Add support for inviting people when creating a new room. * Add support for inviting people when creating a new room.
* Make the homeserver inform the room via `m.room.aliases` when a new alias * Make the homeserver inform the room via `m.room.aliases` when a new alias
@ -923,7 +1046,7 @@ Webclient:
* Handle `m.room.aliases` events. * Handle `m.room.aliases` events.
* Asynchronously send messages and show a local echo. * Asynchronously send messages and show a local echo.
* Inform the UI when a message failed to send. * Inform the UI when a message failed to send.
* Only autoscroll on receiving a new message if the user was already at the * Only autoscroll on receiving a new message if the user was already at the
bottom of the screen. bottom of the screen.
* Add support for ban/kick reasons. * Add support for ban/kick reasons.

View file

@ -11,6 +11,7 @@ recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.py recursive-include synapse/storage/schema *.py
recursive-include docs * recursive-include docs *
recursive-include res *
recursive-include scripts * recursive-include scripts *
recursive-include scripts-dev * recursive-include scripts-dev *
recursive-include tests *.py recursive-include tests *.py

View file

@ -105,7 +105,7 @@ Installing prerequisites on Ubuntu or Debian::
sudo apt-get install build-essential python2.7-dev libffi-dev \ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev libssl-dev python-virtualenv libjpeg-dev libxslt1-dev
Installing prerequisites on ArchLinux:: Installing prerequisites on ArchLinux::
@ -119,7 +119,6 @@ Installing prerequisites on CentOS 7::
python-virtualenv libffi-devel openssl-devel python-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools" sudo yum groupinstall "Development Tools"
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
xcode-select --install xcode-select --install
@ -151,12 +150,7 @@ In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/. above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
Another alternative is to install via apt from http://matrix.org/packages/debian/. Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with
https://github.com/matrix-org/matrix-js-sdk/).
Finally, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy
for details. for details.
@ -230,6 +224,19 @@ For information on how to install and use PostgreSQL, please see
Platform Specific Instructions Platform Specific Instructions
============================== ==============================
Debian
------
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with one of our SDKs :)
Fedora
------
Oleg Girko provides Fedora RPMs at
https://obs.infoserver.lv/project/monitor/matrix-synapse
ArchLinux ArchLinux
--------- ---------
@ -271,11 +278,17 @@ During setup of Synapse you need to call python2.7 directly again::
FreeBSD FreeBSD
------- -------
Synapse can be installed via FreeBSD Ports or Packages: Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean`` - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse`` - Packages: ``pkg install py27-matrix-synapse``
NixOS
-----
Robin Lambertz has packaged Synapse for NixOS at:
https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
Windows Install Windows Install
--------------- ---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages: Synapse can be installed on Cygwin. It requires the following Cygwin packages:
@ -545,6 +558,23 @@ as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (https://matrix.org) at the current we are running a single identity server (https://matrix.org) at the current
time. time.
URL Previews
============
Synapse 0.15.0 introduces an experimental new API for previewing URLs at
/_matrix/media/r0/preview_url. This is disabled by default. To turn it on
you must enable the `url_preview_enabled: True` config parameter and explicitly
specify the IP ranges that Synapse is not allowed to spider for previewing in
the `url_preview_ip_range_blacklist` configuration parameter. This is critical
from a security perspective to stop arbitrary Matrix users spidering 'internal'
URLs on your network. At the very least we recommend that your loopback and
RFC1918 IP addresses are blacklisted.
This also requires the optional lxml and netaddr python dependencies to be
installed.
Password reset Password reset
============== ==============

View file

@ -30,6 +30,14 @@ running:
python synapse/python_dependencies.py | xargs -n1 pip install python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.15.0
====================
If you want to use the new URL previewing API (/_matrix/media/r0/preview_url)
then you have to explicitly enable it in the config and update your dependencies
dependencies. See README.rst for details.
Upgrading to v0.11.0 Upgrading to v0.11.0
==================== ====================

View file

@ -32,5 +32,4 @@ The format of the AS configuration file is as follows:
See the spec_ for further details on how application services work. See the spec_ for further details on how application services work.
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api .. _spec: https://matrix.org/docs/spec/application_service/unstable.html

10
docs/log_contexts.rst Normal file
View file

@ -0,0 +1,10 @@
What do I do about "Unexpected logging context" debug log-lines everywhere?
<Mjark> The logging context lives in thread local storage
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
Unanswered: how and when should we preserve log context?

58
docs/replication.rst Normal file
View file

@ -0,0 +1,58 @@
Replication Architecture
========================
Motivation
----------
We'd like to be able to split some of the work that synapse does into multiple
python processes. In theory multiple synapse processes could share a single
postgresql database and we'd scale up by running more synapse processes.
However much of synapse assumes that only one process is interacting with the
database, both for assigning unique identifiers when inserting into tables,
notifying components about new updates, and for invalidating its caches.
So running multiple copies of the current code isn't an option. One way to
run multiple processes would be to have a single writer process and multiple
reader processes connected to the same database. In order to do this we'd need
a way for the reader process to invalidate its in-memory caches when an update
happens on the writer. One way to do this is for the writer to present an
append-only log of updates which the readers can consume to invalidate their
caches and to push updates to listening clients or pushers.
Synapse already stores much of its data as an append-only log so that it can
correctly respond to /sync requests so the amount of code changes needed to
expose the append-only log to the readers should be fairly minimal.
Architecture
------------
The Replication API
~~~~~~~~~~~~~~~~~~~
Synapse will optionally expose a long poll HTTP API for extracting updates. The
API will have a similar shape to /sync in that clients provide tokens
indicating where in the log they have reached and a timeout. The synapse server
then either responds with updates immediately if it already has updates or it
waits until the timeout for more updates. If the timeout expires and nothing
happened then the server returns an empty response.
However unlike the /sync API this replication API is returning synapse specific
data rather than trying to implement a matrix specification. The replication
results are returned as arrays of rows where the rows are mostly lifted
directly from the database. This avoids unnecessary JSON parsing on the server
and hopefully avoids an impedance mismatch between the data returned and the
required updates to the datastore.
This does not replicate all the database tables as many of the database tables
are indexes that can be recovered from the contents of other tables.
The format and parameters for the api are documented in
``synapse/replication/resource.py``.
The Slaved DataStore
~~~~~~~~~~~~~~~~~~~~
There are read-only version of the synapse storage layer in
``synapse/replication/slave/storage`` that use the response of the replication
API to invalidate their caches.

74
docs/url_previews.rst Normal file
View file

@ -0,0 +1,74 @@
URL Previews
============
Design notes on a URL previewing service for Matrix:
Options are:
1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata.
* Pros:
* Decouples the implementation entirely from Synapse.
* Uses existing Matrix events & content repo to store the metadata.
* Cons:
* Which AS should provide this service for a room, and why should you trust it?
* Doesn't work well with E2E; you'd have to cut the AS into every room
* the AS would end up subscribing to every room anyway.
2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service:
* Pros:
* Simple and flexible; can be used by any clients at any point
* Cons:
* If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI
* We need somewhere to store the URL metadata rather than just using Matrix itself
* We can't piggyback on matrix to distribute the metadata between HSes.
3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata.
* Pros:
* Works transparently for all clients
* Piggy-backs nicely on using Matrix for distributing the metadata.
* No confusion as to which AS
* Cons:
* Doesn't work with E2E
* We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server.
4. Make the sending client use the preview API and insert the event itself when successful.
* Pros:
* Works well with E2E
* No custom server functionality
* Lets the client customise the preview that they send (like on FB)
* Cons:
* Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it.
5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target.
Best solution is probably a combination of both 2 and 4.
* Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending)
* Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one)
This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?"
This is tantamount also to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail. Whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack.
However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable.
As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed.
API
---
GET /_matrix/media/r0/preview_url?url=http://wherever.com
200 OK
{
"og:type" : "article"
"og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672"
"og:title" : "Matrix on Twitter"
"og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png"
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
"og:site_name" : "Twitter"
}
* Downloads the URL
* If HTML, just stores it in RAM and parses it for OG meta tags
* Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs.
* If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents.
* Otherwise, don't bother downloading further.

84
jenkins-dendron-postgres.sh Executable file
View file

@ -0,0 +1,84 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .dendron-base ]]; then
git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
else
(cd .dendron-base; git fetch -p)
fi
rm -rf dendron
git clone .dendron-base dendron --shared
cd dendron
: ${GOPATH:=${WORKSPACE}/.gopath}
if [[ "${GOPATH}" != *:* ]]; then
mkdir -p "${GOPATH}"
export PATH="${GOPATH}/bin:${PATH}"
fi
export GOPATH
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
go get github.com/constabulary/gb/...
gb generate
gb build
cd ..
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--synchrotron \
--pusher \
--port-base $PORT_BASE
cd ..

View file

@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}

View file

@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}

View file

@ -1,86 +0,0 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
database: synapse_jenkins_$port
EOF
fi
done
# Run if both postgresql databases exist
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
else
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
fi
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View file

@ -0,0 +1,7 @@
.header {
border-bottom: 4px solid #e4f7ed ! important;
}
.notif_link a, .footer a {
color: #76CFA6 ! important;
}

156
res/templates/mail.css Normal file
View file

@ -0,0 +1,156 @@
body {
margin: 0px;
}
pre, code {
word-break: break-word;
white-space: pre-wrap;
}
#page {
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
font-color: #454545;
font-size: 12pt;
width: 100%;
padding: 20px;
}
#inner {
width: 640px;
}
.header {
width: 100%;
height: 87px;
color: #454545;
border-bottom: 4px solid #e5e5e5;
}
.logo {
text-align: right;
margin-left: 20px;
}
.salutation {
padding-top: 10px;
font-weight: bold;
}
.summarytext {
}
.room {
width: 100%;
color: #454545;
border-bottom: 1px solid #e5e5e5;
}
.room_header td {
padding-top: 38px;
padding-bottom: 10px;
border-bottom: 1px solid #e5e5e5;
}
.room_name {
vertical-align: middle;
font-size: 18px;
font-weight: bold;
}
.room_header h2 {
margin-top: 0px;
margin-left: 75px;
font-size: 20px;
}
.room_avatar {
width: 56px;
line-height: 0px;
text-align: center;
vertical-align: middle;
}
.room_avatar img {
width: 48px;
height: 48px;
object-fit: cover;
border-radius: 24px;
}
.notif {
border-bottom: 1px solid #e5e5e5;
margin-top: 16px;
padding-bottom: 16px;
}
.historical_message .sender_avatar {
opacity: 0.3;
}
/* spell out opacity and historical_message class names for Outlook aka Word */
.historical_message .sender_name {
color: #e3e3e3;
}
.historical_message .message_time {
color: #e3e3e3;
}
.historical_message .message_body {
color: #c7c7c7;
}
.historical_message td,
.message td {
padding-top: 10px;
}
.sender_avatar {
width: 56px;
text-align: center;
vertical-align: top;
}
.sender_avatar img {
margin-top: -2px;
width: 32px;
height: 32px;
border-radius: 16px;
}
.sender_name {
display: inline;
font-size: 13px;
color: #a2a2a2;
}
.message_time {
text-align: right;
width: 100px;
font-size: 11px;
color: #a2a2a2;
}
.message_body {
}
.notif_link td {
padding-top: 10px;
padding-bottom: 10px;
font-weight: bold;
}
.notif_link a, .footer a {
color: #454545;
text-decoration: none;
}
.debug {
font-size: 10px;
color: #888;
}
.footer {
margin-top: 20px;
text-align: center;
}

45
res/templates/notif.html Normal file
View file

@ -0,0 +1,45 @@
{% for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{% if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
{% else %}
{% if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
{% elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
{% endif %}
</td>
<td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %}
<div class="message_body">
{% if message.msgtype == "m.text" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{% endif %}
</div>
</td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr>
{% endfor %}
<tr class="notif_link">
<td></td>
<td>
<a href="{{ notif.link }}">View {{ room.title }}</a>
</td>
<td></td>
</tr>

16
res/templates/notif.txt Normal file
View file

@ -0,0 +1,16 @@
{% for message in notif.messages %}
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %}
{{ message.body_text_plain }}
{% endif %}
{% endfor %}
View {{ room.title }} at {{ notif.link }}

View file

@ -0,0 +1,53 @@
<!doctype html>
<html lang="en">
<head>
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
</style>
</head>
<body>
<table id="page">
<tr>
<td> </td>
<td id="inner">
<table class="header">
<tr>
<td>
<div class="salutation">Hi {{ user_display_name }},</div>
<div class="summarytext">{{ summary_text }}</div>
</td>
<td class="logo">
{% if app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
</td>
</tr>
</table>
{% for room in rooms %}
{% include 'room.html' with context %}
{% endfor %}
<div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/>
<br/>
<div class="debug">
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
{% if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
and we don't have a last time we sent a mail for this room.
{% endif %}
</div>
</div>
</td>
<td> </td>
</tr>
</table>
</body>
</html>

View file

@ -0,0 +1,10 @@
Hi {{ user_display_name }},
{{ summary_text }}
{% for room in rooms %}
{% include 'room.txt' with context %}
{% endfor %}
You can disable these notifications at {{ unsubscribe_link }}

33
res/templates/room.html Normal file
View file

@ -0,0 +1,33 @@
<table class="room">
<tr class="room_header">
<td class="room_avatar">
{% if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
{% else %}
{% if room.hash % 3 == 0 %}
<img alt="" src="https://vector.im/beta/img/76cfa6.png" />
{% elif room.hash % 3 == 1 %}
<img alt="" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img alt="" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
</td>
<td class="room_name" colspan="2">
{{ room.title }}
</td>
</tr>
{% if room.invite %}
<tr>
<td></td>
<td>
<a href="{{ room.link }}">Join the conversation.</a>
</td>
<td></td>
</tr>
{% else %}
{% for notif in room.notifs %}
{% include 'notif.html' with context %}
{% endfor %}
{% endif %}
</table>

9
res/templates/room.txt Normal file
View file

@ -0,0 +1,9 @@
{{ room.title }}
{% if room.invite %}
You've been invited, join at {{ room.link }}
{% else %}
{% for notif in room.notifs %}
{% include 'notif.txt' with context %}
{% endfor %}
{% endif %}

View file

@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse import argparse
import curses import curses
@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"],
} }
@ -212,6 +214,10 @@ class Porter(object):
self.progress.add_table(table, postgres_size, table_size) self.progress.add_table(table, postgres_size, table_size)
if table == "event_search":
yield self.handle_search_table(postgres_size, table_size, next_chunk)
return
select = ( select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,) % (table,)
@ -230,51 +236,19 @@ class Porter(object):
if rows: if rows:
next_chunk = rows[-1][0] + 1 next_chunk = rows[-1][0] + 1
if table == "event_search": self._convert_rows(table, headers, rows)
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
" VALUES (?,?,?,?,to_tsvector('english', ?))"
)
rows_dict = [ def insert(txn):
dict(zip(headers, row)) self.postgres_store.insert_many_txn(
for row in rows txn, table, headers[1:], rows
] )
txn.executemany(sql, [ self.postgres_store._simple_update_one_txn(
( txn,
row["event_id"], table="port_from_sqlite3",
row["room_id"], keyvalues={"table_name": table},
row["key"], updatevalues={"rowid": next_chunk},
row["sender"], )
row["value"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
else:
self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
@ -284,6 +258,73 @@ class Porter(object):
else: else:
return return
@defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, next_chunk):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
" INNER JOIN events AS e USING (event_id, room_id)"
" WHERE es.rowid >= ?"
" ORDER BY es.rowid LIMIT ?"
)
while True:
def r(txn):
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
next_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)"
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
)
rows_dict = [
dict(zip(headers, row))
for row in rows
]
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
self.progress.update("event_search", postgres_size)
else:
return
def setup_db(self, db_config, database_engine): def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect( db_conn = database_engine.module.connect(
**{ **{
@ -292,7 +333,7 @@ class Porter(object):
} }
) )
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=None)
db_conn.commit() db_conn.commit()
@ -309,8 +350,8 @@ class Porter(object):
**self.postgres_config["args"] **self.postgres_config["args"]
) )
sqlite_engine = create_engine(FakeConfig(sqlite_config)) sqlite_engine = create_engine(sqlite_config)
postgres_engine = create_engine(FakeConfig(postgres_config)) postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,8 +833,3 @@ if __name__ == "__main__":
if end_error_exec_info: if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback) traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View file

@ -17,3 +17,6 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.14.0" __version__ = "0.16.0"

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException from signedjson.sign import verify_signed_json, SignatureVerifyException
@ -22,9 +21,10 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -41,12 +41,20 @@ AuthEventTypes = (
class Auth(object): class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ", "gen = ",
"guest = ", "guest = ",
@ -66,9 +74,9 @@ class Auth(object):
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
self.check_size_limits(event) with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
try:
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None: if auth_events is None:
@ -89,8 +97,8 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,) "Room %r does not exist" % (event.room_id,)
) )
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domain_from_id(event.room_id)
originating_domain = UserID.from_string(event.sender).domain originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain: if creating_domain != originating_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -118,6 +126,24 @@ class Auth(object):
return allowed return allowed
self.check_event_sender_in_room(event, auth_events) self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events) self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels: if event.type == EventTypes.PowerLevels:
@ -127,13 +153,6 @@ class Auth(object):
self.check_redaction(event, auth_events) self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
except AuthError as e:
logger.info(
"Event auth check failed on event %s with msg: %s",
event, e.msg
)
logger.info("Denying! %s", event)
raise
def check_size_limits(self, event): def check_size_limits(self, event):
def too_big(field): def too_big(field):
@ -224,7 +243,7 @@ class Auth(object):
for event in curr_state.values(): for event in curr_state.values():
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
if UserID.from_string(event.state_key).domain != host: if get_domain_from_id(event.state_key) != host:
continue continue
except: except:
logger.warn("state_key not user_id: %s", event.state_key) logger.warn("state_key not user_id: %s", event.state_key)
@ -271,8 +290,8 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domain_from_id(event.room_id)
target_domain = UserID.from_string(target_user_id).domain target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain: if creating_domain != target_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -512,7 +531,7 @@ class Auth(object):
return default return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False): def get_user_by_req(self, request, allow_guest=False, rights="access"):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -534,7 +553,7 @@ class Auth(object):
) )
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token) user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
@ -595,7 +614,7 @@ class Auth(object):
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token): def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -606,7 +625,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
ret = yield self.get_user_from_macaroon(token) ret = yield self.get_user_from_macaroon(token, rights)
except AuthError: except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens # TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons. # have been re-issued as macaroons.
@ -614,10 +633,11 @@ class Auth(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str): def get_user_from_macaroon(self, macaroon_str, rights="access"):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", False)
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None user = None
@ -640,6 +660,13 @@ class Auth(object):
"is_guest": True, "is_guest": True,
"token_id": None, "token_id": None,
} }
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"token_id": None,
}
else: else:
# This codepath exists so that we can actually return a # This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device # token ID, because we use token IDs in place of device
@ -671,7 +698,8 @@ class Auth(object):
Args: Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh") type_string(str): The kind of token required (e.g. "access", "refresh",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired. verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet. token refresh, so we can't enforce expiry yet.
@ -894,8 +922,8 @@ class Auth(object):
if user_level >= redact_level: if user_level >= redact_level:
return False return False
redacter_domain = EventID.from_string(event.event_id).domain redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = EventID.from_string(event.redacts).domain redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain: if redacter_domain == redactee_domain:
return True return True

View file

@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from twisted.internet import defer
import ujson as json import ujson as json
@ -24,10 +26,10 @@ class Filtering(object):
super(Filtering, self).__init__() super(Filtering, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id) result = yield self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(FilterCollection) defer.returnValue(FilterCollection(result))
return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
self.check_valid_filter(user_filter) self.check_valid_filter(user_filter)

View file

@ -16,14 +16,10 @@
import synapse import synapse
import contextlib import gc
import logging import logging
import os import os
import re
import resource
import subprocess
import sys import sys
import time
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.python_dependencies import ( from synapse.python_dependencies import (
@ -33,22 +29,15 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
from twisted.conch.manhole import ColoredManhole
from twisted.conch.insults import insults
from twisted.conch import manhole_ssh
from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request from twisted.web.server import GzipEncoderFactory
from synapse.http.server import RootRedirect from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@ -66,6 +55,13 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.http.site import SynapseSite
from synapse import events from synapse import events
from daemonize import Daemonize from daemonize import Daemonize
@ -73,9 +69,6 @@ from daemonize import Daemonize
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def gz_wrap(r): def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -173,7 +166,12 @@ class SynapseHomeServer(HomeServer):
if name == "replication": if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationResource(self) resources[REPLICATION_PREFIX] = ReplicationResource(self)
root_resource = create_resource_tree(resources) if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
root_resource = create_resource_tree(resources, root_resource)
if tls: if tls:
reactor.listenSSL( reactor.listenSSL(
port, port,
@ -206,24 +204,13 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listener_http(config, listener) self._listener_http(config, listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
matrix="rabbithole"
)
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
ColoredManhole,
{
"__name__": "__console__",
"hs": self,
}
)
f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
f, manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1') interface=listener.get("bind_address", '127.0.0.1')
) )
else: else:
@ -245,7 +232,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self): def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should # Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine. # not be passed to the database engine.
db_params = { db_params = {
@ -254,7 +241,8 @@ class SynapseHomeServer(HomeServer):
} }
db_conn = self.database_engine.module.connect(**db_params) db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn) if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn return db_conn
@ -268,86 +256,6 @@ def quit_with_error(error_string):
sys.exit(1) sys.exit(1)
def get_version_string():
try:
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(__file__))
try:
git_branch = subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip()
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
s for s in
(git_branch, git_tag, git_commit, git_dirty,)
if s
)
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def change_resource_limit(soft_file_no):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if not soft_file_no:
soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no)
resource.setrlimit(
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
logger.warn("Failed to set file or core limit: %s", e)
def setup(config_options): def setup(config_options):
""" """
Args: Args:
@ -377,7 +285,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config # check any extra requirements we have now we have a config
check_requirements(config) check_requirements(config)
version_string = get_version_string() version_string = get_version_string("Synapse", synapse)
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string) logger.info("Server version: %s", version_string)
@ -386,7 +294,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
@ -402,8 +310,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = hs.get_db_conn() db_conn = hs.get_db_conn(run_new_connection=False)
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
db_conn.commit() db_conn.commit()
@ -442,215 +352,13 @@ class SynapseService(service.Service):
def startService(self): def startService(self):
hs = setup(self.config) hs = setup(self.config)
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
def stopService(self): def stopService(self):
return self._port.stopListening() return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
except:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
int(db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def run(hs): def run(hs):
PROFILE_SYNAPSE = False PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE: if PROFILE_SYNAPSE:
@ -717,6 +425,8 @@ def run(hs):
# sys.settrace(logcontext_tracer) # sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
reactor.run() reactor.run()
if hs.config.daemonize: if hs.config.daemonize:

396
synapse/app/pusher.py Normal file
View file

@ -0,0 +1,396 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.emailconfig import EmailConfig
from synapse.config.key import KeyConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import gc
import sys
import logging
logger = logging.getLogger("synapse.app.pusher")
class SlaveConfig(DatabaseConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use", False
)
self.user_agent_suffix = None
self.start_pushers = True
self.listeners = config["listeners"]
self.soft_file_limit = config.get("soft_file_limit")
self.daemonize = config.get("daemonize")
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
thresholds = config.get("gc_thresholds", None)
if thresholds is not None:
try:
assert len(thresholds) == 3
self.gc_thresholds = (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
except:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
else:
self.gc_thresholds = None
# some things used by the auth handler but not actually used in the
# pusher codebase
self.bcrypt_rounds = None
self.ldap_enabled = None
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = None
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
# We would otherwise try to use the registration shared secret as the
# macaroon shared secret if there was no macaroon_shared_secret, but
# that means pulling in RegistrationConfig too. We don't need to be
# backwards compaitible in the pusher codebase so just make people set
# macaroon_shared_secret. We set this to None to prevent it referencing
# an undefined key.
self.registration_shared_secret = None
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
return """\
# Slave configuration
# The replication listener on the synapse to talk to.
#replication_url: https://localhost:{replication_port}/_synapse/replication
server_name: "%(server_name)s"
listeners: []
# Enable a ssh manhole listener on the pusher.
# - type: manhole
# port: {manhole_port}
# bind_address: 127.0.0.1
# Enable a metric listener on the pusher.
# - type: http
# port: {metrics_port}
# bind_address: 127.0.0.1
# resources:
# - names: ["metrics"]
# compress: False
report_stats: False
daemonize: False
pid_file: %(pid_file)s
""" % locals()
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
pass
class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore
):
update_pusher_last_stream_ordering_and_success = (
DataStore.update_pusher_last_stream_ordering_and_success.__func__
)
update_pusher_failing_since = (
DataStore.update_pusher_failing_since.__func__
)
update_pusher_last_stream_ordering = (
DataStore.update_pusher_last_stream_ordering.__func__
)
get_throttle_params_by_room = (
DataStore.get_throttle_params_by_room.__func__
)
set_throttle_params = (
DataStore.set_throttle_params.__func__
)
get_time_of_last_push_action_before = (
DataStore.get_time_of_last_push_action_before.__func__
)
get_profile_displayname = (
DataStore.get_profile_displayname.__func__
)
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def remove_pusher(self, app_id, push_key, user_id):
http_client = self.get_simple_http_client()
replication_url = self.config.replication_url
url = replication_url + "/remove_pushers"
return http_client.post_json_get_json(url, {
"remove": [{
"app_id": app_id,
"push_key": push_key,
"user_id": user_id,
}]
})
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self):
for listener in self.config.listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.replication_url
pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
pushers_for_user = pusher_pool.pushers.get(user_id, {})
pusher = pushers_for_user.pop(key, None)
if pusher is None:
return
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
def start_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks
def poke_pushers(results):
pushers_rows = set(
map(tuple, results.get("pushers", {}).get("rows", []))
)
deleted_pushers_rows = set(
map(tuple, results.get("deleted_pushers", {}).get("rows", []))
)
for row in sorted(pushers_rows | deleted_pushers_rows):
if row in deleted_pushers_rows:
user_id, app_id, pushkey = row[1:4]
stop_pusher(user_id, app_id, pushkey)
elif row in pushers_rows:
user_id = row[1]
app_id = row[5]
pushkey = row[8]
yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events")
if stream:
min_stream_id = stream["rows"][0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_notifications)(
min_stream_id, max_stream_id
)
stream = results.get("receipts")
if stream:
rows = stream["rows"]
affected_room_ids = set(row[1] for row in rows)
min_stream_id = rows[0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_receipts)(
min_stream_id, max_stream_id, affected_room_ids
)
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
poke_pushers(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def setup(config_options):
try:
config = PusherSlaveConfig.load_config(
"Synapse pusher", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
sys.exit(0)
config.setup_logging()
database_engine = create_engine(config.database_config)
ps = PusherServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening()
change_resource_limit(ps.config.soft_file_limit)
if ps.config.gc_thresholds:
gc.set_threshold(*ps.config.gc_thresholds)
def start():
ps.replicate()
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
return ps
if __name__ == '__main__':
with LoggingContext("main"):
ps = setup(sys.argv[1:])
if ps.config.daemonize:
def run():
with LoggingContext("run"):
change_resource_limit(ps.config.soft_file_limit)
if ps.config.gc_thresholds:
gc.set_threshold(*ps.config.gc_thresholds)
reactor.run()
daemon = Daemonize(
app="synapse-pusher",
pid=ps.config.pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
reactor.run()

537
synapse/app/synchrotron.py Normal file
View file

@ -0,0 +1,537 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.api.constants import EventTypes, PresenceState
from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.appservice import AppServiceConfig
from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import contextlib
import gc
import ujson as json
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronConfig(DatabaseConfig, LoggingConfig, AppServiceConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use", False
)
self.user_agent_suffix = None
self.listeners = config["listeners"]
self.soft_file_limit = config.get("soft_file_limit")
self.daemonize = config.get("daemonize")
self.pid_file = self.abspath(config.get("pid_file"))
self.macaroon_secret_key = config["macaroon_secret_key"]
self.expire_access_token = config.get("expire_access_token", False)
thresholds = config.get("gc_thresholds", None)
if thresholds is not None:
try:
assert len(thresholds) == 3
self.gc_thresholds = (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
except:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
else:
self.gc_thresholds = None
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("synchroton.pid")
return """\
# Slave configuration
# The replication listener on the synapse to talk to.
#replication_url: https://localhost:{replication_port}/_synapse/replication
server_name: "%(server_name)s"
listeners:
# Enable a /sync listener on the synchrontron
#- type: http
# port: {http_port}
# bind_address: ""
# Enable a ssh manhole listener on the synchrotron
# - type: manhole
# port: {manhole_port}
# bind_address: 127.0.0.1
# Enable a metric listener on the synchrotron
# - type: http
# port: {metrics_port}
# bind_address: 127.0.0.1
# resources:
# - names: ["metrics"]
# compress: False
report_stats: False
daemonize: False
pid_file: %(pid_file)s
""" % locals()
class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
# XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly.
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.replication_url + "/syncing_users"
self.clock = hs.get_clock()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
self._sending_sync = False
self._need_to_send_sync = False
self.clock.looping_call(
self._send_syncing_users_regularly,
UPDATE_SYNCING_USERS_MS,
)
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_states = yield self.current_state_for_users([user_id])
if prev_states[user_id].state == PresenceState.OFFLINE:
# TODO: Don't block the sync request on this HTTP hit.
yield self._send_syncing_users_now()
def _end():
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1
@contextlib.contextmanager
def _user_syncing():
try:
yield
finally:
_end()
defer.returnValue(_user_syncing())
@defer.inlineCallbacks
def _on_shutdown(self):
# When the synchrotron is shutdown tell the master to clear the in
# progress syncs for this process
self.user_to_num_current_syncs.clear()
yield self._send_syncing_users_now()
def _send_syncing_users_regularly(self):
# Only send an update if we aren't in the middle of sending one.
if not self._sending_sync:
preserve_fn(self._send_syncing_users_now)()
@defer.inlineCallbacks
def _send_syncing_users_now(self):
if self._sending_sync:
# We don't want to race with sending another update.
# Instead we wait for that update to finish and send another
# update afterwards.
self._need_to_send_sync = True
return
# Flag that we are sending an update.
self._sending_sync = True
yield self.http_client.post_json_get_json(self.syncing_users_url, {
"process_id": self.process_id,
"syncing_users": [
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
],
})
# Unset the flag as we are no longer sending an update.
self._sending_sync = False
if self._need_to_send_sync:
# If something happened while we were sending the update then
# we might need to send another update.
# TODO: Check if the update that was sent matches the current state
# as we only need to send an update if they are different.
self._need_to_send_sync = False
yield self._send_syncing_users_now()
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
class SynchrotronTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._room_serials = {}
self._room_typing = {}
def stream_positions(self):
return {"typing": self._latest_room_serial}
def process_replication(self, result):
stream = result.get("typing")
if stream:
self._latest_room_serial = int(stream["position"])
for row in stream["rows"]:
position, room_id, typing_json = row
typing = json.loads(typing_json)
self._room_serials[room_id] = position
self._room_typing[room_id] = typing
class SynchrotronApplicationService(object):
def notify_interested_services(self, event):
pass
class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self):
for listener in self.config.listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
stream = result.get(stream_name)
if stream:
position_index = stream["field_names"].index("position")
if room:
room_index = stream["field_names"].index(room)
if user:
user_index = stream["field_names"].index(user)
users = ()
rooms = ()
for row in stream["rows"]:
position = row[position_index]
if user:
users = (row[user_index],)
if room:
rooms = (row[room_index],)
notifier.on_new_event(
stream_key, position, users=users, rooms=rooms
)
def notify(result):
stream = result.get("events")
if stream:
max_position = stream["position"]
for row in stream["rows"]:
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
notifier.on_new_room_event(
event, position, max_position, extra_users
)
notify_from_stream(
result, "push_rules", "push_rules_key", user="user_id"
)
notify_from_stream(
result, "user_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "room_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "tag_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "receipts", "receipt_key", room="room_id"
)
notify_from_stream(
result, "typing", "typing_key", room="room_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def build_presence_handler(self):
return SynchrotronPresence(self)
def build_typing_handler(self):
return SynchrotronTyping(self)
def setup(config_options):
try:
config = SynchrotronConfig.load_config(
"Synapse synchrotron", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
sys.exit(0)
config.setup_logging()
database_engine = create_engine(config.database_config)
ss = SynchrotronServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
ss.setup()
ss.start_listening()
change_resource_limit(ss.config.soft_file_limit)
if ss.config.gc_thresholds:
ss.set_threshold(*ss.config.gc_thresholds)
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
return ss
if __name__ == '__main__':
with LoggingContext("main"):
ss = setup(sys.argv[1:])
if ss.config.daemonize:
def run():
with LoggingContext("run"):
change_resource_limit(ss.config.soft_file_limit)
if ss.config.gc_thresholds:
gc.set_threshold(*ss.config.gc_thresholds)
reactor.run()
daemon = Daemonize(
app="synapse-synchrotron",
pid=ss.config.pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
reactor.run()

View file

@ -66,6 +66,10 @@ def main():
config = yaml.load(open(configfile)) config = yaml.load(open(configfile))
pidfile = config["pid_file"] pidfile = config["pid_file"]
cache_factor = config.get("synctl_cache_factor", None)
if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
action = sys.argv[1] if sys.argv[1:] else "usage" action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start": if action == "start":

View file

@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("push_bulk to %s threw exception %s", uri, ex) logger.warning("push_bulk to %s threw exception %s", uri, ex)
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event, txn_id=None):
response = yield self.push_bulk(service, [event], txn_id)
defer.returnValue(response)
def _serialize(self, events): def _serialize(self, events):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
return [ return [

View file

@ -56,22 +56,22 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AppServiceScheduler(object): class ApplicationServiceScheduler(object):
""" Public facing API for this module. Does the required DI to tie the """ Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this components together. This also serves as the "event_pool", which in this
case is a simple array. case is a simple array.
""" """
def __init__(self, clock, store, as_api): def __init__(self, hs):
self.clock = clock self.clock = hs.get_clock()
self.store = store self.store = hs.get_datastore()
self.as_api = as_api self.as_api = hs.get_application_service_api()
def create_recoverer(service, callback): def create_recoverer(service, callback):
return _Recoverer(clock, store, as_api, service, callback) return _Recoverer(self.clock, self.store, self.as_api, service, callback)
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
clock, store, as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl)

View file

@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import urllib
import yaml
import logging
logger = logging.getLogger(__name__)
class AppServiceConfig(Config): class AppServiceConfig(Config):
@ -25,3 +34,99 @@ class AppServiceConfig(Config):
# A list of application service config file to use # A list of application service config file to use
app_service_config_files: [] app_service_config_files: []
""" """
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
appservices = []
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = _load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
return appservices
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
)

View file

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file can't be called email.py because if it is, we cannot:
import email.utils
from ._base import Config
class EmailConfig(Config):
def read_config(self, config):
self.email_enable_notifs = False
email_config = config.get("email", {})
self.email_enable_notifs = email_config.get("enable_notifs", False)
if self.email_enable_notifs:
# make sure we can import the required deps
import jinja2
import bleach
# prevent unused warnings
jinja2
bleach
required = [
"smtp_host",
"smtp_port",
"notif_from",
"template_dir",
"notif_template_html",
"notif_template_text",
]
missing = []
for k in required:
if k not in email_config:
missing.append(k)
if (len(missing) > 0):
raise RuntimeError(
"email.enable_notifs is True but required keys are missing: %s" %
(", ".join(["email." + k for k in missing]),)
)
if config.get("public_baseurl") is None:
raise RuntimeError(
"email.enable_notifs is True but no public_baseurl is set"
)
self.email_smtp_host = email_config["smtp_host"]
self.email_smtp_port = email_config["smtp_port"]
self.email_notif_from = email_config["notif_from"]
self.email_template_dir = email_config["template_dir"]
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
self.email_app_name = "Matrix"
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
if parsed[1] == '':
raise RuntimeError("Invalid notif_from address")
else:
self.email_enable_notifs = False
# Not much point setting defaults for the rest: it would be an
# error for them to be used.
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events
#email:
# enable_notifs: false
# smtp_host: "localhost"
# smtp_port: 25
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
# app_name: Matrix
# template_dir: res/templates
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
"""

View file

@ -29,13 +29,16 @@ from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
from .cas import CasConfig from .cas import CasConfig
from .password import PasswordConfig from .password import PasswordConfig
from .jwt import JWTConfig
from .ldap import LDAPConfig
from .emailconfig import EmailConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,): JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
pass pass

54
synapse/config/jwt.py Normal file
View file

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
MISSING_JWT = (
"""Missing jwt library. This is required for jwt login.
Install by running:
pip install pyjwt
"""
)
class JWTConfig(Config):
def read_config(self, config):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
try:
import jwt
jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
else:
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
def default_config(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
# jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
"""

View file

@ -57,6 +57,8 @@ class KeyConfig(Config):
seed = self.signing_key[0].seed seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed) self.macaroon_secret_key = hashlib.sha256(seed)
self.expire_access_token = config.get("expire_access_token", False)
def default_config(self, config_dir_path, server_name, is_generating_file=False, def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs): **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
@ -69,6 +71,9 @@ class KeyConfig(Config):
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s" macaroon_secret_key: "%(macaroon_secret_key)s"
# Used to enable access token expiration.
expire_access_token: False
## Signing Keys ## ## Signing Keys ##
# Path to the signing key to sign messages with # Path to the signing key to sign messages with

52
synapse/config/ldap.py Normal file
View file

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class LDAPConfig(Config):
def read_config(self, config):
ldap_config = config.get("ldap_config", None)
if ldap_config:
self.ldap_enabled = ldap_config.get("enabled", False)
self.ldap_server = ldap_config["server"]
self.ldap_port = ldap_config["port"]
self.ldap_tls = ldap_config.get("tls", False)
self.ldap_search_base = ldap_config["search_base"]
self.ldap_search_property = ldap_config["search_property"]
self.ldap_email_property = ldap_config["email_property"]
self.ldap_full_name_property = ldap_config["full_name_property"]
else:
self.ldap_enabled = False
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = False
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
# server: "ldap://localhost"
# port: 389
# tls: false
# search_base: "ou=Users,dc=example,dc=com"
# search_property: "cn"
# email_property: "email"
# full_name_property: "givenName"
"""

View file

@ -32,6 +32,7 @@ class RegistrationConfig(Config):
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
# Sets the expiry for the short term user creation in
# milliseconds. For instance the bellow duration is two weeks
# in milliseconds.
user_creation_max_duration: 1209600000
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.

View file

@ -13,9 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
from collections import namedtuple from collections import namedtuple
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
)
MISSING_LXML = (
"""Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
)
ThumbnailRequirement = namedtuple( ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"] "ThumbnailRequirement", ["width", "height", "method", "media_type"]
) )
@ -23,7 +39,7 @@ ThumbnailRequirement = namedtuple(
def parse_thumbnail_requirements(thumbnail_sizes): def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys """ Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumnailing and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate method, and thumbnail media type to precalculate
Args: Args:
@ -53,12 +69,44 @@ class ContentRepositoryConfig(Config):
def read_config(self, config): def read_config(self, config):
self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"])
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"] self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"] config["thumbnail_sizes"]
) )
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
import lxml
lxml # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_LXML)
try:
from netaddr import IPSet
except ImportError:
raise ConfigError(MISSING_NETADDR)
if "url_preview_ip_range_blacklist" in config:
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
else:
raise ConfigError(
"For security, you must specify an explicit target IP address "
"blacklist in url_preview_ip_range_blacklist for url previewing "
"to work"
)
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
self.url_preview_url_blacklist = config.get(
"url_preview_url_blacklist", ()
)
def default_config(self, **kwargs): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
@ -80,7 +128,7 @@ class ContentRepositoryConfig(Config):
# the resolution requested by the client. If true then whenever # the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will # a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail # generate a new thumbnail. If false the server will pick a thumbnail
# from a precalcualted list. # from a precalculated list.
dynamic_thumbnails: false dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded. # List of thumbnail to precalculate when an image is uploaded.
@ -100,4 +148,71 @@ class ContentRepositoryConfig(Config):
- width: 800 - width: 800
height: 600 height: 600
method: scale method: scale
# Is the preview URL API enabled? If enabled, you *must* specify
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
# denied from accessing.
url_preview_enabled: False
# List of IP address CIDR ranges that the URL preview spider is denied
# from accessing. There are no defaults: you must explicitly
# specify a list for URL previewing to work. You should specify any
# internal services in your network that you do not want synapse to try
# to connect to, otherwise anyone in any Matrix room could cause your
# synapse to issue arbitrary GET requests to your internal services,
# causing serious security issues.
#
# url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
# This is useful for specifying exceptions to wide-ranging blacklisted
# target IP ranges - e.g. for enabling URL previews for a specific private
# website only visible in your network.
#
# url_preview_ip_range_whitelist:
# - '192.168.1.1'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist
# in preference to this, otherwise someone could define a public DNS
# entry that points to a private IP address and circumvent the blacklist.
# This is more useful if you know there is an entire shape of URL that
# you know that will never want synapse to try to spider.
#
# Each list entry is a dictionary of url component attributes as returned
# by urlparse.urlsplit as applied to the absolute form of the URL. See
# https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
# The values of the dictionary are treated as an filename match pattern
# applied to that component of URLs, unless they start with a ^ in which
# case they are treated as a regular expression match. If all the
# specified component matches for a given list item succeed, the URL is
# blacklisted.
#
# url_preview_url_blacklist:
# # blacklist any URL with a username in its URI
# - username: '*'
#
# # blacklist all *.google.com URLs
# - netloc: 'google.com'
# - netloc: '*.google.com'
#
# # blacklist all plain HTTP URLs
# - scheme: 'http'
#
# # blacklist http(s)://www.acme.com/foo
# - netloc: 'www.acme.com'
# path: '/foo'
#
# # blacklist any URL with a literal IPv4 address
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
# The largest allowed URL preview spidering size in bytes
max_spider_size: "10M"
""" % locals() """ % locals()

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
class ServerConfig(Config): class ServerConfig(Config):
@ -28,9 +28,30 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile") self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix") self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.public_baseurl = config.get("public_baseurl")
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
self.start_pushers = config.get("start_pushers", True)
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
thresholds = config.get("gc_thresholds", None)
if thresholds is not None:
try:
assert len(thresholds) == 3
self.gc_thresholds = (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
except:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
else:
self.gc_thresholds = None
bind_port = config.get("bind_port") bind_port = config.get("bind_port")
if bind_port: if bind_port:
self.listeners = [] self.listeners = []
@ -142,11 +163,26 @@ class ServerConfig(Config):
# Whether to serve a web client from the HTTP/HTTPS root resource. # Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True web_client: True
# The public-facing base URL for the client API (not including _matrix/...)
# public_baseurl: https://example.com:8448/
# Set the soft limit on the number of file descriptors synapse can use # Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the # Zero is used to indicate synapse should set the soft limit to the
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
# A list of other Home Servers to fetch the public room directory from
# and include in the public room directory of this home server
# This is a temporary stopgap solution to populate new server with a
# list of rooms until there exists a good solution of a decentralized
# room directory.
# secondary_directory_servers:
# - matrix.org
# - vector.im
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:

View file

@ -31,7 +31,10 @@ class _EventInternalMetadata(object):
return dict(self.__dict__) return dict(self.__dict__)
def is_outlier(self): def is_outlier(self):
return hasattr(self, "outlier") and self.outlier return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key): def _event_dict_property(key):

View file

@ -24,6 +24,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -550,6 +551,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_public_rooms(self, destinations):
results_by_server = {}
@defer.inlineCallbacks
def _get_result(s):
if s == self.server_name:
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """

View file

@ -387,6 +387,11 @@ class FederationServer(FederationBase):
"events": [ev.get_pdu_json(time_now) for ev in missing_events], "events": [ev.get_pdu_json(time_now) for ev in missing_events],
}) })
@log_function
def on_openid_userinfo(self, token):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.

View file

@ -20,6 +20,7 @@ from .persistence import TransactionActions
from .units import Transaction from .units import Transaction
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import ( from synapse.util.retryutils import (
@ -199,6 +200,8 @@ class TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
yield run_on_reactor()
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending

View file

@ -179,7 +179,8 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
retry_on_dns_fail=True, retry_on_dns_fail=False,
timeout=20000,
) )
defer.returnValue(content) defer.returnValue(content)
@ -223,6 +224,18 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server):
path = PREFIX + "/publicRooms"
response = yield self.client.get_json(
destination=remote_server,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, destination, room_id, event_dict): def exchange_third_party_invite(self, destination, room_id, event_dict):

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
@ -134,10 +134,12 @@ class Authenticator(object):
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter, server_name): def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
def _wrap(self, code): def _wrap(self, code):
authenticator = self.authenticator authenticator = self.authenticator
@ -323,7 +325,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
def on_GET(self, origin, content, query, context, event_id): def on_GET(self, origin, content, query, context, event_id):
return self.handler.on_event_auth(origin, context, event_id) return self.handler.on_event_auth(origin, context, event_id)
@ -448,6 +450,94 @@ class On3pidBindServlet(BaseFederationServlet):
return code return code
class OpenIdUserInfo(BaseFederationServlet):
"""
Exchange a bearer token for information about a user.
The response format should be compatible with:
http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"sub": "@userpart:example.org",
}
"""
PATH = "/openid/userinfo"
@defer.inlineCallbacks
def on_GET(self, request):
token = parse_string(request, "access_token")
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
}))
return
user_id = yield self.handler.on_openid_userinfo(token)
if user_id is None:
defer.returnValue((401, {
"errcode": "M_UNKNOWN_TOKEN",
"error": "Access Token unknown or expired"
}))
defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet):
"""
Fetch the public room list for this server.
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
intended for consumption by other home servers.
GET /publicRooms HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"chunk": [
{
"aliases": [
"#test:localhost"
],
"guest_can_join": false,
"name": "test room",
"num_joined_members": 3,
"room_id": "!whkydVegtvatLfXmPN:localhost",
"world_readable": false
}
],
"end": "END",
"start": "START"
}
"""
PATH = "/publicRooms"
@defer.inlineCallbacks
def on_GET(self, request):
data = yield self.room_list_handler.get_local_public_room_list()
defer.returnValue((200, data))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -468,6 +558,8 @@ SERVLET_CLASSES = (
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
OpenIdUserInfo,
PublicRoomList,
) )
@ -478,4 +570,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
authenticator=authenticator, authenticator=authenticator,
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,
room_list_handler=hs.get_room_list_handler(),
).register(resource) ).register(resource)

View file

@ -13,23 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, RoomCreationHandler, RoomContextHandler,
) )
from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler from .search import SearchHandler
@ -52,22 +46,9 @@ class Handlers(object):
self.event_handler = EventHandler(hs) self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs) self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(
clock=hs.get_clock(),
store=hs.get_datastore(),
as_api=asapi
)
)
self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs) self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs) self.room_context_handler = RoomContextHandler(hs)

View file

@ -15,13 +15,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias, Requester from synapse.types import UserID, Requester
from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext
import logging import logging
@ -29,20 +26,13 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISIBILITY_PRIORITY = (
"world_readable",
"shared",
"invited",
"joined",
)
class BaseHandler(object): class BaseHandler(object):
""" """
Common base class for the event handlers. Common base class for the event handlers.
:type store: synapse.storage.events.StateStore Attributes:
:type state_handler: synapse.state.StateHandler store (synapse.storage.events.StateStore):
state_handler (synapse.state.StateHandler):
""" """
def __init__(self, hs): def __init__(self, hs):
@ -55,137 +45,10 @@ class BaseHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to
see.
:param (str, bool) user_tuples: (user id, is_peeking) for each
user to be checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_peeking):
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
# if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable":
return True
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = "shared"
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.)
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN:
return True
if visibility == "joined":
# we weren't a member at the time of the event, so we can't
# see this event.
return False
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
return membership == Membership.INVITE
else:
# visibility is shared: user can also see the event if they have
# become a member since the event
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
# we don't know when they left.
return not is_peeking
defer.returnValue({
user_id: [
event
for event in events
if allowed(event, user_id, is_peeking)
]
for user_id, is_peeking in user_tuples
})
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False):
"""
Check which events a user is allowed to see
:param str user_id: user id to be checked
:param [synapse.events.EventBase] events: list of events to be checked
:param bool is_peeking should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
:rtype [synapse.events.EventBase]
"""
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
res = yield self.filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state
)
defer.returnValue(res.get(user_id, []))
def ratelimit(self, requester): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
@ -198,95 +61,6 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures(
builder, self.server_name, self.signing_key
)
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
def is_host_in_room(self, current_state): def is_host_in_room(self, current_state):
room_members = [ room_members = [
(state_key, event.membership) (state_key, event.membership)
@ -301,143 +75,12 @@ class BaseHandler(object):
return True return True
for (state_key, membership) in room_members: for (state_key, membership) in room_members:
if ( if (
UserID.from_string(state_key).domain == self.hs.hostname self.hs.is_mine_id(state_key)
and membership == Membership.JOIN and membership == Membership.JOIN
): ):
return True return True
return False return False
@defer.inlineCallbacks
def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
UserID.from_string(s.state_key).domain
)
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
event, destinations=destinations,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state): def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.

View file

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID
import logging import logging
@ -35,16 +34,13 @@ def log_failure(failure):
) )
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object): class ApplicationServicesHandler(object):
def __init__(self, hs, appservice_api, appservice_scheduler): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.is_mine_id = hs.is_mine_id
self.appservice_api = appservice_api self.appservice_api = hs.get_application_service_api()
self.scheduler = appservice_scheduler self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False self.started_scheduler = False
@defer.inlineCallbacks @defer.inlineCallbacks
@ -169,8 +165,7 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
user = UserID.from_string(user_id) if not self.is_mine_id(user_id):
if not self.hs.is_mine(user):
# we don't know if they are unknown or not since it isn't one of our # we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes. # users. We can't poke ASes.
defer.returnValue(False) defer.returnValue(False)

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401 self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
self.ldap_server = hs.config.ldap_server
self.ldap_port = hs.config.ldap_port
self.ldap_tls = hs.config.ldap_tls
self.ldap_search_base = hs.config.ldap_search_base
self.ldap_search_property = hs.config.ldap_search_property
self.ldap_email_property = hs.config.ldap_email_property
self.ldap_full_name_property = hs.config.ldap_full_name_property
if self.ldap_enabled is True:
import ldap
logger.info("Import ldap version: %s", ldap.__version__)
self.hs = hs # FIXME better possibility to access registrationHandler later?
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
""" """
@ -163,9 +178,13 @@ class AuthHandler(BaseHandler):
def get_session_id(self, clientdict): def get_session_id(self, clientdict):
""" """
Gets the session ID for a client given the client dictionary Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not Args:
send a session ID, returns None. clientdict: The dictionary sent by the client in the request
Returns:
str|None: The string session ID the client sent. If the client did
not send a session ID, returns None.
""" """
sid = None sid = None
if clientdict and 'auth' in clientdict: if clientdict and 'auth' in clientdict:
@ -179,9 +198,11 @@ class AuthHandler(BaseHandler):
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by request. This data is stored server-side and cannot be modified by
the client. the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under Args:
:param value: (any) The data to store session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
value (any): The data to store
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value sess.setdefault('serverdict', {})[key] = value
@ -190,9 +211,11 @@ class AuthHandler(BaseHandler):
def get_session_data(self, session_id, key, default=None): def get_session_data(self, session_id, key, default=None):
""" """
Retrieve data stored with set_session_data Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under Args:
:param default: (any) Value to return if the key has not been set session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
default (any): Value to return if the key has not been set
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@ -207,8 +230,10 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) if not (yield self._check_password(user_id, password)):
self._check_password(user_id, password, password_hash) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -332,8 +357,10 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash) if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
@ -399,11 +426,67 @@ class AuthHandler(BaseHandler):
else: else:
defer.returnValue(user_infos.popitem()) defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash): @defer.inlineCallbacks
"""Checks that user_id has passed password, raises LoginError if not.""" def _check_password(self, user_id, password):
if not self.validate_hash(password, stored_hash): """
logger.warn("Failed password login for user %s", user_id) Returns:
raise LoginError(403, "", errcode=Codes.FORBIDDEN) True if the user_id successfully authenticated
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(True)
valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
if not self.ldap_enabled:
logger.debug("LDAP not configured")
defer.returnValue(False)
import ldap
logger.info("Authenticating %s with LDAP" % user_id)
try:
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
logger.debug("Connecting LDAP server at %s" % ldap_url)
l = ldap.initialize(ldap_url)
if self.ldap_tls:
logger.debug("Initiating TLS")
self._connection.start_tls_s()
local_name = UserID.from_string(user_id).localpart
dn = "%s=%s, %s" % (
self.ldap_search_property,
local_name,
self.ldap_search_base)
logger.debug("DN for LDAP authentication: %s" % dn)
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
if not (yield self.does_user_exist(user_id)):
handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield handler.register(localpart=local_name)
)
defer.returnValue(True)
except ldap.LDAPError, e:
logger.warn("LDAP error: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id): def issue_access_token(self, user_id):
@ -438,14 +521,19 @@ class AuthHandler(BaseHandler):
)) ))
return m.serialize() return m.serialize()
def generate_short_term_login_token(self, user_id): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000) expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize() return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
@ -480,7 +568,12 @@ class AuthHandler(BaseHandler):
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_ids = [requester.access_token_id] if requester else []
yield self.store.user_set_password_hash(user_id, password_hash) try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_ids
) )
@ -532,4 +625,7 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
return bcrypt.hashpw(password, stored_hash) == stored_hash if stored_hash:
return bcrypt.hashpw(password, stored_hash) == stored_hash
else:
return False

View file

@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler):
super(DirectoryHandler, self).__init__(hs) super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler):
) )
if not result: if not result:
# Query AS to see if it exists # Query AS to see if it exists
as_handler = self.hs.get_handlers().appservice_handler as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias) result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result) defer.returnValue(result)

View file

@ -58,7 +58,7 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down. If `only_keys` is not None, events from keys will be sent down.
""" """
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing( context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence, auth_user_id, affect_presence=affect_presence,

View file

@ -26,20 +26,21 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures, compute_event_signature, add_hashes_and_signatures,
) )
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room
from twisted.internet import defer from twisted.internet import defer
@ -49,10 +50,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user, room_id)
class FederationHandler(BaseHandler): class FederationHandler(BaseHandler):
"""Handles events that originated from federation. """Handles events that originated from federation.
Responsible for: Responsible for:
@ -69,10 +66,6 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.distributor.observe("user_joined_room", self.user_joined_room)
self.waiting_for_join_list = {}
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -102,8 +95,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -174,11 +166,7 @@ class FederationHandler(BaseHandler):
}) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
yield self._handle_new_events( yield self._handle_new_events(origin, event_infos)
origin,
event_infos,
outliers=True
)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -288,7 +276,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return
be successfull and still return no events if the other side has no new
events to offer.
""" """
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities: if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id) extremities = yield self.store.get_oldest_events_in_room(room_id)
@ -299,6 +294,16 @@ class FederationHandler(BaseHandler):
extremities=extremities, extremities=extremities,
) )
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
events = [e for e in events if e.event_id not in seen_events]
if not events:
defer.returnValue([])
event_map = {e.event_id: e for e in events} event_map = {e.event_id: e for e in events}
event_ids = set(e.event_id for e in events) event_ids = set(e.event_id for e in events)
@ -358,6 +363,7 @@ class FederationHandler(BaseHandler):
for a in auth_events.values(): for a in auth_events.values():
if a.event_id in seen_events: if a.event_id in seen_events:
continue continue
a.internal_metadata.outlier = True
ev_infos.append({ ev_infos.append({
"event": a, "event": a,
"auth_events": { "auth_events": {
@ -378,20 +384,23 @@ class FederationHandler(BaseHandler):
} }
}) })
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
for event in events: for event in events:
if event in events_to_state: if event in events_to_state:
continue continue
ev_infos.append({ # We store these one at a time since each event depends on the
"event": event, # previous to work out the state.
}) # TODO: We can probably do something more clever here.
yield self._handle_new_event(
yield self._handle_new_events( dest, event
dest, ev_infos, )
backfilled=True,
)
defer.returnValue(events) defer.returnValue(events)
@ -440,7 +449,7 @@ class FederationHandler(BaseHandler):
joined_domains = {} joined_domains = {}
for u, d in joined_users: for u, d in joined_users:
try: try:
dom = UserID.from_string(u).domain dom = get_domain_from_id(u)
old_d = joined_domains.get(dom) old_d = joined_domains.get(dom)
if old_d: if old_d:
joined_domains[dom] = min(d, old_d) joined_domains[dom] = min(d, old_d)
@ -455,7 +464,7 @@ class FederationHandler(BaseHandler):
likely_domains = [ likely_domains = [
domain for domain, depth in curr_domains domain for domain, depth in curr_domains
if domain is not self.server_name if domain != self.server_name
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@ -463,11 +472,15 @@ class FederationHandler(BaseHandler):
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
events = yield self.backfill( yield self.backfill(
dom, room_id, dom, room_id,
limit=100, limit=100,
extremities=[e for e in extremities.keys()] extremities=[e for e in extremities.keys()]
) )
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
defer.returnValue(True)
except SynapseError as e: except SynapseError as e:
logger.info( logger.info(
"Failed to backfill from %s because %s", "Failed to backfill from %s because %s",
@ -493,8 +506,6 @@ class FederationHandler(BaseHandler):
) )
continue continue
if events:
defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
success = yield try_backfill(likely_domains) success = yield try_backfill(likely_domains)
@ -666,9 +677,14 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
event, context = yield self._create_new_client_event( try:
builder=builder, message_handler = self.hs.get_handlers().message_handler
) event, context = yield message_handler._create_new_client_event(
builder=builder,
)
except AuthError as e:
logger.warn("Failed to create join %r because %s", event, e)
raise e
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
@ -724,9 +740,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(get_domain_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
@ -761,6 +775,7 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -788,13 +803,19 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id): def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event( try:
target_hosts, origin, event = yield self._make_and_verify_event(
room_id, target_hosts,
user_id, room_id,
"leave" user_id,
) "leave"
signed_event = self._sign_event(event) )
signed_event = self._sign_event(event)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
@ -804,10 +825,16 @@ class FederationHandler(BaseHandler):
except ValueError: except ValueError:
pass pass
yield self.replication_layer.send_leave( try:
target_hosts, yield self.replication_layer.send_leave(
signed_event target_hosts,
) signed_event
)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
@ -883,11 +910,16 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
event, context = yield self._create_new_client_event( message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
self.auth.check(event, auth_events=context.current_state) try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
defer.returnValue(event) defer.returnValue(event)
@ -934,9 +966,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE: if s.content["membership"] == Membership.LEAVE:
destinations.add( destinations.add(get_domain_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
@ -1057,21 +1087,10 @@ class FederationHandler(BaseHandler):
def get_min_depth_for_context(self, context): def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context) return self.store.get_min_depth(context)
@log_function
def user_joined_room(self, user, room_id):
waiters = self.waiting_for_join_list.get(
(user.to_string(), room_id),
[]
)
while waiters:
waiters.pop().callback(None)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None): def _handle_new_event(self, origin, event, state=None, auth_events=None,
backfilled=False):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
@ -1081,20 +1100,30 @@ class FederationHandler(BaseHandler):
if not event.internal_metadata.is_outlier(): if not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, context, self event, context
) )
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
is_new_state=not outlier, backfilled=backfilled,
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
) )
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False, def _handle_new_events(self, origin, event_infos, backfilled=False):
outliers=False): """Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
contexts = yield defer.gatherResults( contexts = yield defer.gatherResults(
[ [
self._prep_event( self._prep_event(
@ -1113,7 +1142,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in itertools.izip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1128,11 +1156,9 @@ class FederationHandler(BaseHandler):
""" """
events_to_context = {} events_to_context = {}
for e in itertools.chain(auth_events, state): for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
ctx = yield self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
event_map = { event_map = {
e.event_id: e e.event_id: e
@ -1176,16 +1202,14 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
], ],
is_new_state=False,
) )
new_event_context = yield self.state_handler.compute_event_context( new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False, event, old_state=state
) )
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
is_new_state=True,
current_state=state, current_state=state,
) )
@ -1193,10 +1217,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None): def _prep_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier, event, old_state=state,
) )
if not auth_events: if not auth_events:
@ -1482,8 +1505,9 @@ class FederationHandler(BaseHandler):
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=auth_events)
except AuthError: except AuthError as e:
raise logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth): def construct_auth_difference(self, local_auth, remote_auth):
@ -1653,13 +1677,21 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder) message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
builder=builder
)
event, context = yield self.add_display_name_to_third_party_invite( event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context event_dict, event, context
) )
self.auth.check(event, context.current_state) try:
self.auth.check(event, context.current_state)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
@ -1676,7 +1708,8 @@ class FederationHandler(BaseHandler):
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
event, context = yield self._create_new_client_event( message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -1684,7 +1717,11 @@ class FederationHandler(BaseHandler):
event_dict, event, context event_dict, event, context
) )
self.auth.check(event, auth_events=context.current_state) try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
@ -1711,20 +1748,23 @@ class FederationHandler(BaseHandler):
event_dict["content"]["third_party_invite"]["display_name"] = display_name event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder) message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(builder=builder)
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_signature(self, event, auth_events): def _check_signature(self, event, auth_events):
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises Args:
AuthError if signature didn't match any keys, or key has been event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>):
Raises:
AuthError: if signature didn't match any keys, or key has been
revoked, revoked,
SynapseError if a transient error meant a key couldn't be checked SynapseError: if a transient error meant a key couldn't be checked
for revocation. for revocation.
""" """
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
@ -1766,12 +1806,13 @@ class FederationHandler(BaseHandler):
""" """
Checks whether public_key has been revoked. Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key. Args:
:param url (str): Key revocation URL. public_key (str): base-64 encoded public key.
url (str): Key revocation URL.
:raises Raises:
AuthError if they key has been revoked. AuthError: if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked SynapseError: if a transient error meant a key couldn't be checked
for revocation. for revocation.
""" """
try: try:

View file

@ -17,12 +17,19 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.streams.config import PaginationConfig from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -33,10 +40,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class MessageHandler(BaseHandler): class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -47,35 +50,6 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = SnapshotCache() self.snapshot_cache = SnapshotCache()
@defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None,
user_id=None):
""" Retrieve a message.
Args:
msg_id (str): The message ID to obtain.
room_id (str): The room where the message resides.
sender_id (str): The user ID of the user who sent the message.
user_id (str): The user ID of the user making this request.
Returns:
The message, or None if no message exists.
Raises:
SynapseError if something went wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
# Pull out the message from the db
# msg = yield self.store.get_message(
# room_id=room_id,
# msg_id=msg_id,
# user_id=sender_id
# )
# TODO (erikj): Once we work out the correct c-s api we need to think
# on how to do this.
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True):
@ -155,7 +129,8 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client( events = yield filter_events_for_client(
self.store,
user_id, user_id,
events, events,
is_peeking=(member_event_id is None), is_peeking=(member_event_id is None),
@ -175,7 +150,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None): def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
""" """
Given a dict from a client, create a new event. Given a dict from a client, create a new event.
@ -186,6 +161,9 @@ class MessageHandler(BaseHandler):
Args: Args:
event_dict (dict): An entire event event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns: Returns:
Tuple of created event (FrozenEvent), Context Tuple of created event (FrozenEvent), Context
@ -198,12 +176,8 @@ class MessageHandler(BaseHandler):
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key) target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN: if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield collect_presencelike_data(
self.distributor, target, builder.content
)
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler profile = self.hs.get_handlers().profile_handler
content = builder.content content = builder.content
@ -224,6 +198,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
prev_event_ids=prev_event_ids,
) )
defer.returnValue((event, context)) defer.returnValue((event, context))
@ -261,7 +236,7 @@ class MessageHandler(BaseHandler):
) )
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user) yield presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context): def deduplicate_state_event(self, event, context):
@ -515,8 +490,8 @@ class MessageHandler(BaseHandler):
] ]
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield filter_events_for_client(
user_id, messages self.store, user_id, messages
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -556,14 +531,7 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
# Only do N rooms at once yield concurrently_execute(handle_room, room_list, 10)
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
account_data_events = [] account_data_events = []
for account_data_type, content in account_data.items(): for account_data_type, content in account_data.items():
@ -658,8 +626,8 @@ class MessageHandler(BaseHandler):
end_token=stream_token end_token=stream_token
) )
messages = yield self._filter_events_for_client( messages = yield filter_events_for_client(
user_id, messages, is_peeking=is_peeking self.store, user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken.START.copy_and_replace("room_key", token[0]) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
@ -706,7 +674,7 @@ class MessageHandler(BaseHandler):
and m.content["membership"] == Membership.JOIN and m.content["membership"] == Membership.JOIN
] ]
presence_handler = self.hs.get_handlers().presence_handler presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
@ -739,8 +707,8 @@ class MessageHandler(BaseHandler):
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield filter_events_for_client(
user_id, messages, is_peeking=is_peeking, self.store, user_id, messages, is_peeking=is_peeking,
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -763,3 +731,196 @@ class MessageHandler(BaseHandler):
ret["membership"] = membership ret["membership"] = membership
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
)
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
preserve_fn(_notify)()
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
event, destinations=destinations,
)

View file

@ -33,11 +33,9 @@ from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
import synapse.metrics import synapse.metrics
from ._base import BaseHandler
import logging import logging
@ -52,6 +50,8 @@ timers_fired_counter = metrics.register_counter("timers_fired")
federation_presence_counter = metrics.register_counter("federation_presence") federation_presence_counter = metrics.register_counter("federation_presence")
bump_active_time_counter = metrics.register_counter("bump_active_time") bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active" # "currently_active"
@ -70,14 +70,18 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers # How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000 FEDERATION_PING_INTERVAL = 25 * 60 * 1000
# How long we will wait before assuming that the syncs from an external process
# are dead.
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(BaseHandler): class PresenceHandler(object):
def __init__(self, hs): def __init__(self, hs):
super(PresenceHandler, self).__init__(hs) self.is_mine = hs.is_mine
self.hs = hs self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.wheel_timer = WheelTimer() self.wheel_timer = WheelTimer()
@ -138,7 +142,7 @@ class PresenceHandler(BaseHandler):
obj=state.user_id, obj=state.user_id,
then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
) )
if self.hs.is_mine_id(state.user_id): if self.is_mine_id(state.user_id):
self.wheel_timer.insert( self.wheel_timer.insert(
now=now, now=now,
obj=state.user_id, obj=state.user_id,
@ -160,15 +164,26 @@ class PresenceHandler(BaseHandler):
self.serial_to_user = {} self.serial_to_user = {}
self._next_serial = 1 self._next_serial = 1
# Keeps track of the number of *ongoing* syncs. While this is non zero # Keeps track of the number of *ongoing* syncs on this process. While
# a user will never go offline. # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
# go offline.
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = {}
# Start a LoopingCall in 30s that fires every 5s. # Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to # The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline. # reconnect before we treat them as offline.
self.clock.call_later( self.clock.call_later(
0 * 1000, 30,
self.clock.looping_call, self.clock.looping_call,
self._handle_timeouts, self._handle_timeouts,
5000, 5000,
@ -228,7 +243,7 @@ class PresenceHandler(BaseHandler):
new_state, should_notify, should_ping = handle_update( new_state, should_notify, should_ping = handle_update(
prev_state, new_state, prev_state, new_state,
is_mine=self.hs.is_mine_id(user_id), is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer, wheel_timer=self.wheel_timer,
now=now now=now
) )
@ -268,31 +283,48 @@ class PresenceHandler(BaseHandler):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
""" """
logger.info("Handling presence timeouts")
now = self.clock.time_msec() now = self.clock.time_msec()
with Measure(self.clock, "presence_handle_timeouts"): try:
# Fetch the list of users that *may* have timed out. Things may have with Measure(self.clock, "presence_handle_timeouts"):
# changed since the timeout was set, so we won't necessarily have to # Fetch the list of users that *may* have timed out. Things may have
# take any action. # changed since the timeout was set, so we won't necessarily have to
users_to_check = self.wheel_timer.fetch(now) # take any action.
users_to_check = set(self.wheel_timer.fetch(now))
states = [ # Check whether the lists of syncing processes from an external
self.user_to_current_state.get( # process have expired.
user_id, UserPresenceState.default(user_id) expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_last_updated_ms.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in users_to_check
]
timers_fired_counter.inc_by(len(states))
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_user_ids=self.get_currently_syncing_users(),
now=now,
) )
for user_id in set(users_to_check)
]
timers_fired_counter.inc_by(len(states)) preserve_fn(self._update_states)(changes)
except:
changes = handle_timeouts( logger.exception("Exception in _handle_timeouts loop")
states,
is_mine_fn=self.hs.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs,
now=now,
)
preserve_fn(self._update_states)(changes)
@defer.inlineCallbacks @defer.inlineCallbacks
def bump_presence_active_time(self, user): def bump_presence_active_time(self, user):
@ -365,6 +397,74 @@ class PresenceHandler(BaseHandler):
defer.returnValue(_user_syncing()) defer.returnValue(_user_syncing())
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
syncing_user_ids = {
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
syncing_user_ids.update(user_ids)
return syncing_user_ids
@defer.inlineCallbacks
def update_external_syncs(self, process_id, syncing_user_ids):
"""Update the syncing users for an external process
Args:
process_id(str): An identifier for the process the users are
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
syncing_user_ids(set(str)): The set of user_ids that are
currently syncing on that server.
"""
# Grab the previous list of user_ids that were syncing on that process
prev_syncing_user_ids = (
self.external_process_to_current_syncs.get(process_id, set())
)
# Grab the current presence state for both the users that are syncing
# now and the users that were syncing before this update.
prev_states = yield self.current_state_for_users(
syncing_user_ids | prev_syncing_user_ids
)
updates = []
time_now_ms = self.clock.time_msec()
# For each new user that is syncing check if we need to mark them as
# being online.
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
prev_state = prev_states[new_user_id]
if prev_state.state == PresenceState.OFFLINE:
updates.append(prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=time_now_ms,
last_user_sync_ts=time_now_ms,
))
else:
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
# For each user that is still syncing or stopped syncing update the
# last sync time so that we will correctly apply the grace period when
# they stop syncing.
for old_user_id in prev_syncing_user_ids:
prev_state = prev_states[old_user_id]
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
yield self._update_states(updates)
# Update the last updated time for the process. We expire the entries
# if we don't receive an update in the given timeframe.
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
self.external_process_to_current_syncs[process_id] = syncing_user_ids
@defer.inlineCallbacks @defer.inlineCallbacks
def current_state_for_user(self, user_id): def current_state_for_user(self, user_id):
"""Get the current presence state for a user. """Get the current presence state for a user.
@ -427,7 +527,7 @@ class PresenceHandler(BaseHandler):
hosts_to_states = {} hosts_to_states = {}
for room_id, states in room_ids_to_states.items(): for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states: if not local_states:
continue continue
@ -436,11 +536,11 @@ class PresenceHandler(BaseHandler):
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items(): for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states: if not local_states:
continue continue
host = UserID.from_string(user_id).domain host = get_domain_from_id(user_id)
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
# TODO: de-dup hosts_to_states, as a single host might have multiple # TODO: de-dup hosts_to_states, as a single host might have multiple
@ -611,14 +711,14 @@ class PresenceHandler(BaseHandler):
# don't need to send to local clients here, as that is done as part # don't need to send to local clients here, as that is done as part
# of the event stream/sync. # of the event stream/sync.
# TODO: Only send to servers not already in the room. # TODO: Only send to servers not already in the room.
if self.hs.is_mine(user): if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string()) state = yield self.current_state_for_user(user.to_string())
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id)
self._push_to_remotes({host: (state,) for host in hosts}) self._push_to_remotes({host: (state,) for host in hosts})
else: else:
user_ids = yield self.store.get_users_in_room(room_id) user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.hs.is_mine_id, user_ids) user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids) states = yield self.current_state_for_users(user_ids)
@ -628,7 +728,7 @@ class PresenceHandler(BaseHandler):
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list. """Returns the presence for all users in their presence list.
""" """
if not self.hs.is_mine(observer_user): if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list( presence_list = yield self.store.get_presence_list(
@ -659,7 +759,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
if self.hs.is_mine(observed_user): if self.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user) yield self.invite_presence(observed_user, observer_user)
else: else:
yield self.federation.send_edu( yield self.federation.send_edu(
@ -675,11 +775,11 @@ class PresenceHandler(BaseHandler):
def invite_presence(self, observed_user, observer_user): def invite_presence(self, observed_user, observer_user):
"""Handles new presence invites. """Handles new presence invites.
""" """
if not self.hs.is_mine(observed_user): if not self.is_mine(observed_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
# TODO: Don't auto accept # TODO: Don't auto accept
if self.hs.is_mine(observer_user): if self.is_mine(observer_user):
yield self.accept_presence(observed_user, observer_user) yield self.accept_presence(observed_user, observer_user)
else: else:
self.federation.send_edu( self.federation.send_edu(
@ -742,7 +842,7 @@ class PresenceHandler(BaseHandler):
Returns: Returns:
A Deferred. A Deferred.
""" """
if not self.hs.is_mine(observer_user): if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list( yield self.store.del_presence_list(
@ -834,7 +934,11 @@ def _format_user_presence_state(state, now):
class PresenceEventSource(object): class PresenceEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs # We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -860,7 +964,7 @@ class PresenceEventSource(object):
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or [] room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache stream_change_cache = self.store.presence_stream_cache
if not room_ids: if not room_ids:
@ -877,13 +981,13 @@ class PresenceEventSource(object):
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
if from_key and max_token - from_key < 100: if from_key:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
changed = stream_change_cache.get_all_entities_changed(from_key) changed = stream_change_cache.get_all_entities_changed(from_key)
# get_all_entities_changed can return None if changed is not None and len(changed) < 500:
if changed is not None: # For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.inc("stream")
for other_user_id in changed: for other_user_id in changed:
if other_user_id in friends: if other_user_id in friends:
user_ids_changed.add(other_user_id) user_ids_changed.add(other_user_id)
@ -895,6 +999,8 @@ class PresenceEventSource(object):
else: else:
# Too many possible updates. Find all users we can see and check # Too many possible updates. Find all users we can see and check
# if any of them have changed. # if any of them have changed.
get_updates_counter.inc("full")
user_ids_to_check = set() user_ids_to_check = set()
for room_id in room_ids: for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id) users = yield self.store.get_users_in_room(room_id)
@ -933,15 +1039,14 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False) return self.get_new_events(user, from_key=None, include_offline=False)
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
Args: Args:
user_states(list): List of UserPresenceState's to check. user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours is_mine_fn (fn): Function that returns if a user_id is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -952,21 +1057,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
for state in user_states: for state in user_states:
is_mine = is_mine_fn(state.user_id) is_mine = is_mine_fn(state.user_id)
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
if new_state: if new_state:
changes[state.user_id] = new_state changes[state.user_id] = new_state
return changes.values() return changes.values()
def handle_timeout(state, is_mine, user_to_num_current_syncs, now): def handle_timeout(state, is_mine, syncing_user_ids, now):
"""Checks the presence of the user to see if any of the timers have elapsed """Checks the presence of the user to see if any of the timers have elapsed
Args: Args:
state (UserPresenceState) state (UserPresenceState)
is_mine (bool): Whether the user is ours is_mine (bool): Whether the user is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -1000,7 +1104,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
# If there are have been no sync for a while (and none ongoing), # If there are have been no sync for a while (and none ongoing),
# set presence to offline # set presence to offline
if not user_to_num_current_syncs.get(user_id, 0): if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace( state = state.copy_and_replace(
state=PresenceState.OFFLINE, state=PresenceState.OFFLINE,

View file

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester from synapse.types import UserID, Requester
from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -27,14 +26,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def changed_presencelike_data(distributor, user, state):
return distributor.fire("changed_presencelike_data", user, state)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class ProfileHandler(BaseHandler): class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -46,17 +37,9 @@ class ProfileHandler(BaseHandler):
) )
distributor = hs.get_distributor() distributor = hs.get_distributor()
self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user) distributor.observe("registered_user", self.registered_user)
distributor.observe(
"collect_presencelike_data", self.collect_presencelike_data
)
def registered_user(self, user): def registered_user(self, user):
return self.store.create_profile(user.localpart) return self.store.create_profile(user.localpart)
@ -105,10 +88,6 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )
yield changed_presencelike_data(self.distributor, target_user, {
"displayname": new_displayname,
})
yield self._update_join_states(requester) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -152,30 +131,8 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url target_user.localpart, new_avatar_url
) )
yield changed_presencelike_data(self.distributor, target_user, {
"avatar_url": new_avatar_url,
})
yield self._update_join_states(requester) yield self._update_join_states(requester)
@defer.inlineCallbacks
def collect_presencelike_data(self, user, state):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
state["displayname"] = displayname
state["avatar_url"] = avatar_url
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_profile_query(self, args): def on_profile_query(self, args):
user = UserID.from_string(args["user_id"]) user = UserID.from_string(args["user_id"])

View file

@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs) super(ReceiptsHandler, self).__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_edu_handler( self.federation.register_edu_handler(
@ -80,6 +82,9 @@ class ReceiptsHandler(BaseHandler):
def _handle_new_receipts(self, receipts): def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier. """Takes a list of receipts, stores them and informs the notifier.
""" """
min_batch_id = None
max_batch_id = None
for receipt in receipts: for receipt in receipts:
room_id = receipt["room_id"] room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"] receipt_type = receipt["receipt_type"]
@ -97,10 +102,21 @@ class ReceiptsHandler(BaseHandler):
stream_id, max_persisted_id = res stream_id, max_persisted_id = res
with PreserveLoggingContext(): if min_batch_id is None or stream_id < min_batch_id:
self.notifier.on_new_event( min_batch_id = stream_id
"receipt_key", max_persisted_id, rooms=[room_id] if max_batch_id is None or max_persisted_id > max_batch_id:
) max_batch_id = max_persisted_id
affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
defer.returnValue(True) defer.returnValue(True)
@ -117,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"] event_ids = receipt["event_ids"]
data = receipt["data"] data = receipt["data"]
remotedomains = set() remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
rm_handler = self.hs.get_handlers().room_member_handler remotedomains.discard(self.server_name)
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains) logger.debug("Sending receipt to: %r", remotedomains)

View file

@ -16,13 +16,14 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID, Requester
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler from ._base import BaseHandler
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.util.distributor import registered_user
import logging import logging
import urllib import urllib
@ -30,10 +31,6 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
class RegistrationHandler(BaseHandler): class RegistrationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -361,8 +358,62 @@ class RegistrationHandler(BaseHandler):
) )
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
yield run_on_reactor()
if localpart is None:
raise SynapseError(400, "Request must include user id")
need_register = True
try:
yield self.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
auth_handler = self.hs.get_handlers().auth_handler
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield registered_user(self.distributor, user)
else:
yield self.store.user_delete_access_tokens(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname(
user, Requester(user, token, False), displayname
)
defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def guest_access_token_for(self, medium, address, inviter_user_id):

View file

@ -18,19 +18,17 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from signedjson.sign import verify_signed_json from synapse.visibility import filter_events_for_client
from signedjson.key import decode_verify_key_bytes
from collections import OrderedDict from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging import logging
import math import math
@ -38,23 +36,11 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://" id_server_scheme = "https://"
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
PRESETS_DICT = { PRESETS_DICT = {
@ -356,598 +342,31 @@ class RoomCreationHandler(BaseHandler):
) )
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": effective_membership_state}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
requester,
event,
context,
ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
yield federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, context, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(context.current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE:
return UserID.from_string(prev_state.user_id)
return None
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
:param id_server (str): hostname + optional port for the identity server.
:param medium (str): The literal string "email".
:param address (str): The third party address being invited.
:param room_id (str): The ID of the room to which the user is invited.
:param inviter_user_id (str): The user ID of the inviter.
:param room_alias (str): An alias for the room, for cosmetic
notifications.
:param room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
:param room_join_rules (str): The join rules of the email
(e.g. "public").
:param room_name (str): The m.room.name of the room.
:param inviter_display_name (str): The current display name of the
inviter.
:param inviter_avatar_url (str): The URL of the inviter's avatar.
:return: A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
def forget(self, user, room_id):
return self.store.forget(user.to_string(), room_id)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_local_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks @defer.inlineCallbacks
def get_public_room_list(self): def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids() room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_room(room_id): def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
# We pull each bit of state out indvidually to avoid pulling the # We pull each bit of state out indvidually to avoid pulling the
# full state into memory. Due to how the caching works this should # full state into memory. Due to how the caching works this should
# be fairly quick, even if not originally in the cache. # be fairly quick, even if not originally in the cache.
@ -962,6 +381,14 @@ class RoomListHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
result = {"room_id": room_id} result = {"room_id": room_id}
joined_users = yield self.store.get_users_in_room(room_id)
if len(joined_users) == 0:
return
result["num_joined_members"] = len(joined_users)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases: if aliases:
result["aliases"] = aliases result["aliases"] = aliases
@ -1001,21 +428,61 @@ class RoomListHandler(BaseHandler):
if avatar_url: if avatar_url:
result["avatar_url"] = avatar_url result["avatar_url"] = avatar_url
joined_users = yield self.store.get_users_in_room(room_id) results.append(result)
result["num_joined_members"] = len(joined_users)
defer.returnValue(result) yield concurrently_execute(handle_room, room_ids, 10)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result}) defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from out cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add the them to the de-deping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@ -1040,10 +507,12 @@ class RoomContextHandler(BaseHandler):
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
def filter_evts(events): def filter_evts(events):
return self._filter_events_for_client( return filter_events_for_client(
self.store,
user.to_string(), user.to_string(),
events, events,
is_peeking=is_guest) is_peeking=is_guest
)
event = yield self.store.get_event(event_id, get_prev_content=True, event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True) allow_none=True)

View file

@ -0,0 +1,677 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
prev_event_ids,
txn_id=None,
ratelimit=True,
):
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
)
yield msg_handler.handle_new_client_event(
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id
)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
if not remote_room_hosts:
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
errcode=Codes.BAD_STATE
)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
try:
ret = yield self.reject_remote_invite(
target.to_string(), room_id, remote_room_hosts
)
defer.returnValue(ret)
except SynapseError as e:
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield message_handler.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
yield self.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
Args:
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)

View file

@ -21,6 +21,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.visibility import filter_events_for_client
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
@ -172,8 +173,8 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield self._filter_events_for_client( events = yield filter_events_for_client(
user.to_string(), filtered_events self.store, user.to_string(), filtered_events
) )
events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
@ -223,8 +224,8 @@ class SearchHandler(BaseHandler):
r["event"] for r in results r["event"] for r in results
]) ])
events = yield self._filter_events_for_client( events = yield filter_events_for_client(
user.to_string(), filtered_events self.store, user.to_string(), filtered_events
) )
room_events.extend(events) room_events.extend(events)
@ -281,12 +282,12 @@ class SearchHandler(BaseHandler):
event.room_id, event.event_id, before_limit, after_limit event.room_id, event.event_id, before_limit, after_limit
) )
res["events_before"] = yield self._filter_events_for_client( res["events_before"] = yield filter_events_for_client(
user.to_string(), res["events_before"] self.store, user.to_string(), res["events_before"]
) )
res["events_after"] = yield self._filter_events_for_client( res["events_after"] = yield filter_events_for_client(
user.to_string(), res["events_after"] self.store, user.to_string(), res["events_after"]
) )
res["start"] = now_token.copy_and_replace( res["start"] = now_token.copy_and_replace(

File diff suppressed because it is too large Load diff

View file

@ -15,8 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -32,14 +30,16 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping # A tiny object useful for storing a user's membership in a room, as a mapping
# key # key
RoomMember = namedtuple("RoomMember", ("room_id", "user")) RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
class TypingNotificationHandler(BaseHandler): class TypingHandler(object):
def __init__(self, hs): def __init__(self, hs):
super(TypingNotificationHandler, self).__init__(hs) self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.homeserver = hs self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -67,20 +67,23 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout): def started_typing(self, target_user, auth_user, room_id, timeout):
if not self.hs.is_mine(target_user): target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string()) yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug( logger.debug(
"%s has started typing in %s", target_user.to_string(), room_id "%s has started typing in %s", target_user_id, room_id
) )
until = self.clock.time_msec() + timeout until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until was_present = member in self._member_typing_until
@ -104,25 +107,28 @@ class TypingNotificationHandler(BaseHandler):
yield self._push_update( yield self._push_update(
room_id=room_id, room_id=room_id,
user=target_user, user_id=target_user_id,
typing=True, typing=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id): def stopped_typing(self, target_user, auth_user, room_id):
if not self.hs.is_mine(target_user): target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string()) yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug( logger.debug(
"%s has stopped typing in %s", target_user.to_string(), room_id "%s has stopped typing in %s", target_user_id, room_id
) )
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer: if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member]) self.clock.cancel_call_later(self._member_typing_timer[member])
@ -132,8 +138,9 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):
if self.hs.is_mine(user): user_id = user.to_string()
member = RoomMember(room_id=room_id, user=user) if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -144,7 +151,7 @@ class TypingNotificationHandler(BaseHandler):
yield self._push_update( yield self._push_update(
room_id=member.room_id, room_id=member.room_id,
user=member.user, user_id=member.user_id,
typing=False, typing=False,
) )
@ -156,61 +163,53 @@ class TypingNotificationHandler(BaseHandler):
del self._member_typing_timer[member] del self._member_typing_timer[member]
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user, typing): def _push_update(self, room_id, user_id, typing):
localusers = set() domains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers, remotedomains=remotedomains
)
if localusers:
self._push_update_local(
room_id=room_id,
user=user,
typing=typing
)
deferreds = [] deferreds = []
for domain in remotedomains: for domain in domains:
deferreds.append(self.federation.send_edu( if domain == self.server_name:
destination=domain, self._push_update_local(
edu_type="m.typing", room_id=room_id,
content={ user_id=user_id,
"room_id": room_id, typing=typing
"user_id": user.to_string(), )
"typing": typing, else:
}, deferreds.append(self.federation.send_edu(
)) destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user_id,
"typing": typing,
},
))
yield defer.DeferredList(deferreds, consumeErrors=True) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):
room_id = content["room_id"] room_id = content["room_id"]
user = UserID.from_string(content["user_id"]) user_id = content["user_id"]
localusers = set() # Check that the string is a valid user id
UserID.from_string(user_id)
rm_handler = self.homeserver.get_handlers().room_member_handler domains = yield self.store.get_joined_hosts_for_room(room_id)
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers
)
if localusers: if self.server_name in domains:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user_id=user_id,
typing=content["typing"] typing=content["typing"]
) )
def _push_update_local(self, room_id, user, typing): def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set()) room_set = self._room_typing.setdefault(room_id, set())
if typing: if typing:
room_set.add(user) room_set.add(user_id)
else: else:
room_set.discard(user) room_set.discard(user_id)
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
@ -226,9 +225,7 @@ class TypingNotificationHandler(BaseHandler):
for room_id, serial in self._room_serials.items(): for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id: if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id] typing = self._room_typing[room_id]
typing_bytes = json.dumps([ typing_bytes = json.dumps(list(typing), ensure_ascii=False)
u.to_string() for u in typing
], ensure_ascii=False)
rows.append((serial, room_id, typing_bytes)) rows.append((serial, room_id, typing_bytes))
rows.sort() rows.sort()
return rows return rows
@ -238,34 +235,26 @@ class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._handler = None # We can't call get_typing_handler here because there's a cycle:
self._room_member_handler = None #
# Typing -> Notifier -> TypingNotificationEventSource -> Typing
def handler(self): #
# Avoid cyclic dependency in handler setup self.get_typing_handler = hs.get_typing_handler
if not self._handler:
self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler
def room_member_handler(self):
if not self._room_member_handler:
self._room_member_handler = self.hs.get_handlers().room_member_handler
return self._room_member_handler
def _make_event_for(self, room_id): def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id] typing = self.get_typing_handler()._room_typing[room_id]
return { return {
"type": "m.typing", "type": "m.typing",
"room_id": room_id, "room_id": room_id,
"content": { "content": {
"user_ids": [u.to_string() for u in typing], "user_ids": list(typing),
}, },
} }
def get_new_events(self, from_key, room_ids, **kwargs): def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"): with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.get_typing_handler()
events = [] events = []
for room_id in room_ids: for room_id in room_ids:
@ -279,7 +268,7 @@ class TypingNotificationEventSource(object):
return events, handler._latest_room_serial return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.get_typing_handler()._latest_room_serial
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_key) return ([], pagination_config.from_key)

View file

@ -15,17 +15,24 @@
from OpenSSL import SSL from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import CodeMessageException from synapse.api.errors import (
CodeMessageException, SynapseError, Codes,
)
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl from twisted.internet import defer, reactor, ssl, protocol
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.web.client import ( from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError, BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, FileBodyProducer, PartialDownloadError,
) )
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
from StringIO import StringIO from StringIO import StringIO
@ -238,6 +245,107 @@ class SimpleHttpClient(object):
else: else:
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
response = yield self.request(
"GET",
url.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
headers = dict(response.headers.getAllRawHeaders())
if 'Content-Length' in headers and headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url))
raise SynapseError(
502,
"Got error %d" % (response.code,),
Codes.UNKNOWN,
)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
# straight back in again
try:
length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
502,
("Failed to download remote body: %s" % e),
Codes.UNKNOWN,
)
defer.returnValue((length, headers, response.request.absoluteURI, response.code))
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
))
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
# stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
""" """
@ -269,6 +377,60 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response) defer.returnValue(e.response)
class SpiderEndpointFactory(object):
def __init__(self, hs):
self.blacklist = hs.config.url_preview_ip_range_blacklist
self.whitelist = hs.config.url_preview_ip_range_whitelist
self.policyForHTTPS = hs.get_http_client_context_factory()
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
},
)
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,
'timeout': 15
},
)
else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
class SpiderHttpClient(SimpleHttpClient):
"""
Separate HTTP client for spidering arbitrary URLs.
Special in that it follows retries and has a UA that looks
like a browser.
used by the preview_url endpoint in the content repo.
"""
def __init__(self, hs):
SimpleHttpClient.__init__(self, hs)
# clobber the base class's agent and UA:
self.agent = ContentDecoderAgent(
BrowserLikeRedirectAgent(
Agent.usingEndpointFactory(
reactor,
SpiderEndpointFactory(hs)
)
), [('gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
# Chrome Safari" % hs.version_string)
def encode_urlencode_args(args): def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()} return {k: encode_urlencode_arg(v) for k, v in args.items()}
@ -301,5 +463,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
self._context = SSL.Context(SSL.SSLv23_METHOD) self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None) self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port): def getContext(self, hostname=None, port=None):
return self._context return self._context
def creatorForNetloc(self, hostname, port):
return self

View file

@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections import collections
import logging import logging
import random import random
import time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port" "_Server", "priority weight host port expires"
) )
@ -74,6 +75,41 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
return transport_endpoint(reactor, domain, port, **endpoint_kw_args) return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
class SpiderEndpoint(object):
"""An endpoint which refuses to connect to blacklisted IP addresses
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
self.blacklist = blacklist
self.whitelist = whitelist
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
@defer.inlineCallbacks
def connect(self, protocolFactory):
address = yield self.reactor.resolve(self.host)
from netaddr import IPAddress
ip_address = IPAddress(address)
if ip_address in self.blacklist:
if self.whitelist is None or ip_address not in self.whitelist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
logger.info("Connecting to %s:%s", address, self.port)
endpoint = self.endpoint(
self.reactor, address, self.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
class SRVClientEndpoint(object): class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service. """An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect Cycles through the list of servers starting with each call to connect
@ -92,7 +128,8 @@ class SRVClientEndpoint(object):
host=domain, host=domain,
port=default_port, port=default_port,
priority=0, priority=0,
weight=0 weight=0,
expires=0,
) )
else: else:
self.default_server = None self.default_server = None
@ -118,7 +155,7 @@ class SRVClientEndpoint(object):
return self.default_server return self.default_server
else: else:
raise ConnectError( raise ConnectError(
"Not server available for %s", self.service_name "Not server available for %s" % self.service_name
) )
min_priority = self.servers[0].priority min_priority = self.servers[0].priority
@ -153,7 +190,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = [] servers = []
try: try:
@ -166,34 +209,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
and answers[0].type == dns.SRV and answers[0].type == dns.SRV
and answers[0].payload and answers[0].payload
and answers[0].payload.target == dns.Name('.')): and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", service_name) raise ConnectError("Service %s unavailable" % service_name)
for answer in answers: for answer in answers:
if answer.type != dns.SRV or not answer.payload: if answer.type != dns.SRV or not answer.payload:
continue continue
payload = answer.payload payload = answer.payload
host = str(payload.target) host = str(payload.target)
srv_ttl = answer.ttl
try: try:
answers, _, _ = yield dns_client.lookupAddress(host) answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError: except DNSNameError:
continue continue
ips = [ for answer in answers:
answer.payload.dottedQuad() if answer.type == dns.A and answer.payload:
for answer in answers ip = answer.payload.dottedQuad()
if answer.type == dns.A and answer.payload host_ttl = min(srv_ttl, answer.ttl)
]
for ip in ips: servers.append(_Server(
servers.append(_Server( host=ip,
host=ip, port=int(payload.port),
port=int(payload.port), priority=int(payload.priority),
priority=int(payload.priority), weight=int(payload.weight),
weight=int(payload.weight) expires=int(clock.time()) + host_ttl,
)) ))
servers.sort() servers.sort()
cache[service_name] = list(servers) cache[service_name] = list(servers)

View file

@ -74,7 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0 _next_request_id = 0
def request_handler(request_handler): def request_handler(report_metrics=True):
"""Decorator for ``wrap_request_handler``"""
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
def wrap_request_handler(request_handler, report_metrics):
"""Wraps a method that acts as a request handler with the necessary logging """Wraps a method that acts as a request handler with the necessary logging
and exception handling. and exception handling.
@ -96,7 +101,12 @@ def request_handler(request_handler):
global _next_request_id global _next_request_id
request_id = "%s-%s" % (request.method, _next_request_id) request_id = "%s-%s" % (request.method, _next_request_id)
_next_request_id += 1 _next_request_id += 1
with LoggingContext(request_id) as request_context: with LoggingContext(request_id) as request_context:
if report_metrics:
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try: try:
@ -133,6 +143,14 @@ def request_handler(request_handler):
}, },
send_cors=True send_cors=True
) )
finally:
try:
if report_metrics:
request_metrics.stop(
self.clock, request, self.__class__.__name__
)
except:
pass
return wrapped_request_handler return wrapped_request_handler
@ -197,19 +215,23 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request) self._async_render(request)
return server.NOT_DONE_YET return server.NOT_DONE_YET
@request_handler # Disable metric reporting because _async_render does its own metrics.
# It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself.
@request_handler(report_metrics=False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request):
""" This gets called from render() every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
start = self.clock.time_msec()
if request.method == "OPTIONS": if request.method == "OPTIONS":
self._send_response(request, 200, {}) self._send_response(request, 200, {})
return return
start_context = LoggingContext.current_context() request_metrics = RequestMetrics()
request_metrics.start(self.clock)
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
@ -241,40 +263,7 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, code, response) self._send_response(request, code, response)
try: try:
context = LoggingContext.current_context() request_metrics.stop(self.clock, request, servlet_classname)
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag
)
except: except:
pass pass
@ -307,6 +296,48 @@ class JsonResource(HttpServer, resource.Resource):
) )
class RequestMetrics(object):
def start(self, clock):
self.start = clock.time_msec()
self.start_context = LoggingContext.current_context()
def stop(self, clock, request, servlet_classname):
context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag
)
class RootRedirect(resource.Resource): class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path.""" """Redirects the root '/' path to another path."""

View file

@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False): def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (int|None): value to use if the parameter is absent, defaults
parameter is absent, defaults to False. to None.
:return: An int value or the default. required (bool): whether to raise a 400 SynapseError if the
:raises parameter is absent, defaults to False.
SynapseError if the parameter is absent and required, or if the
Returns:
int|None: An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer. parameter is present and not an integer.
""" """
if name in request.args: if name in request.args:
@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
def parse_boolean(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string """Parse a boolean parameter from the request query string
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (bool|None): value to use if the parameter is absent, defaults
parameter is absent, defaults to False. to None.
:return: A bool value or the default. required (bool): whether to raise a 400 SynapseError if the
:raises parameter is absent, defaults to False.
SynapseError if the parameter is absent and required, or if the
Returns:
bool|None: A bool value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false". parameter is present and not one of "true" or "false".
""" """
@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string. """Parse a string parameter from the request query string.
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (str|None): value to use if the parameter is absent, defaults
parameter is absent, defaults to False. to None.
:param allowed_values (list): List of allowed values for the string, required (bool): whether to raise a 400 SynapseError if the
or None if any value is allowed, defaults to None parameter is absent, defaults to False.
:return: A string value or the default. allowed_values (list[str]): List of allowed values for the string,
:raises or None if any value is allowed, defaults to None
Returns:
str|None: A string value or the default.
Raises:
SynapseError if the parameter is absent and required, or if the SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False,
def parse_json_value_from_request(request): def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request. """Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request. Args:
:returns: The JSON value. request: the twisted HTTP request.
:raises
Returns:
The JSON value.
Raises:
SynapseError if the request body couldn't be decoded as JSON. SynapseError if the request body couldn't be decoded as JSON.
""" """
try: try:
@ -143,8 +162,10 @@ def parse_json_value_from_request(request):
def parse_json_object_from_request(request): def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request. """Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request. Args:
:raises request: the twisted HTTP request.
Raises:
SynapseError if the request body couldn't be decoded as JSON or SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object. if it wasn't a JSON object.
""" """

146
synapse/http/site.py Normal file
View file

@ -0,0 +1,146 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import LoggingContext
from twisted.web.server import Site, Request
import contextlib
import logging
import re
import time
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
except:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
int(db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass

View file

@ -22,6 +22,7 @@ import functools
import os import os
import stat import stat
import time import time
import gc
from twisted.internet import reactor from twisted.internet import reactor
@ -33,11 +34,7 @@ from .metric import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# We'll keep all the available metrics in a single toplevel dict, one shared all_metrics = []
# for the entire process. We don't currently support per-HomeServer instances
# of metrics, because in practice any one python VM will host only one
# HomeServer anyway. This makes a lot of implementation neater
all_metrics = {}
class Metrics(object): class Metrics(object):
@ -53,7 +50,7 @@ class Metrics(object):
metric = metric_class(full_name, *args, **kwargs) metric = metric_class(full_name, *args, **kwargs)
all_metrics[full_name] = metric all_metrics.append(metric)
return metric return metric
def register_counter(self, *args, **kwargs): def register_counter(self, *args, **kwargs):
@ -84,12 +81,12 @@ def render_all():
# TODO(paul): Internal hack # TODO(paul): Internal hack
update_resource_metrics() update_resource_metrics()
for name in sorted(all_metrics.keys()): for metric in all_metrics:
try: try:
strs += all_metrics[name].render() strs += metric.render()
except Exception: except Exception:
strs += ["# FAILED to render %s" % name] strs += ["# FAILED to render"]
logger.exception("Failed to render %s metric", name) logger.exception("Failed to render metric")
strs.append("") # to generate a final CRLF strs.append("") # to generate a final CRLF
@ -156,6 +153,13 @@ reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time") tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls") pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
reactor_metrics.register_callback(
"gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
)
def runUntilCurrentTimer(func): def runUntilCurrentTimer(func):
@ -182,6 +186,22 @@ def runUntilCurrentTimer(func):
end = time.time() * 1000 end = time.time() * 1000
tick_time.inc_by(end - start) tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending) pending_calls_metric.inc_by(num_pending)
# Check if we need to do a manual GC (since its been disabled), and do
# one if necessary.
threshold = gc.get_threshold()
counts = gc.get_count()
for i in (2, 1, 0):
if threshold[i] < counts[i]:
logger.info("Collecting gc %d", i)
start = time.time() * 1000
unreachable = gc.collect(i)
end = time.time() * 1000
gc_time.inc_by(end - start, i)
gc_unreachable.inc_by(unreachable, i)
return ret return ret
return f return f
@ -196,5 +216,9 @@ try:
# runUntilCurrent is called when we have pending calls. It is called once # runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling. # per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
gc.disable()
except AttributeError: except AttributeError:
pass pass

View file

@ -47,9 +47,6 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)]) for k, v in zip(self.labels, values)])
) )
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CounterMetric(BaseMetric): class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing """The simplest kind of metric; one that stores a monotonically-increasing
@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
def render_item(self, k): def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CallbackMetric(BaseMetric): class CallbackMetric(BaseMetric):
"""A metric that returns the numeric value returned by a callback whenever """A metric that returns the numeric value returned by a callback whenever
@ -126,30 +126,30 @@ class DistributionMetric(object):
class CacheMetric(object): class CacheMetric(object):
"""A combination of two CounterMetrics, one to count cache hits and one to __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
count a total, and a callback metric to yield the current size.
This metric generates standard metric name pairs, so that monitoring rules def __init__(self, name, size_callback, cache_name):
can easily be applied to measure hit ratio."""
def __init__(self, name, size_callback, labels=[]):
self.name = name self.name = name
self.cache_name = cache_name
self.hits = CounterMetric(name + ":hits", labels=labels) self.hits = 0
self.total = CounterMetric(name + ":total", labels=labels) self.misses = 0
self.size = CallbackMetric( self.size_callback = size_callback
name + ":size",
callback=size_callback,
labels=labels,
)
def inc_hits(self, *values): def inc_hits(self):
self.hits.inc(*values) self.hits += 1
self.total.inc(*values)
def inc_misses(self, *values): def inc_misses(self):
self.total.inc(*values) self.misses += 1
def render(self): def render(self):
return self.hits.render() + self.total.render() + self.size.render() size = self.size_callback()
hits = self.hits
total = self.misses + self.hits
return [
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]

View file

@ -14,13 +14,14 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
import synapse.metrics import synapse.metrics
from collections import namedtuple from collections import namedtuple
@ -139,8 +140,6 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.appservice_to_user_streams = {} self.appservice_to_user_streams = {}
@ -150,10 +149,8 @@ class Notifier(object):
self.pending_new_room_events = [] self.pending_new_room_events = []
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
hs.get_distributor().observe( self.state_handler = hs.get_state_handler()
"user_joined_room", self._user_joined_room
)
self.clock.looping_call( self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
@ -231,9 +228,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.appservice_handler.notify_interested_services(event)
event
)
app_streams = set() app_streams = set()
@ -249,6 +244,9 @@ class Notifier(object):
) )
app_streams |= app_user_streams app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
self.on_new_event( self.on_new_event(
"room_key", room_stream_id, "room_key", room_stream_id,
users=extra_users, users=extra_users,
@ -398,8 +396,8 @@ class Notifier(object):
) )
if name == "room": if name == "room":
room_member_handler = self.hs.get_handlers().room_member_handler new_events = yield filter_events_for_client(
new_events = yield room_member_handler._filter_events_for_client( self.store,
user.to_string(), user.to_string(),
new_events, new_events,
is_peeking=is_peeking, is_peeking=is_peeking,
@ -448,7 +446,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state( state = yield self.state_handler.get_current_state(
room_id, room_id,
EventTypes.RoomHistoryVisibility EventTypes.RoomHistoryVisibility
) )
@ -484,9 +482,8 @@ class Notifier(object):
user_stream.appservice, set() user_stream.appservice, set()
).add(user_stream) ).add(user_stream)
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user_id, room_id):
user = str(user) new_user_stream = self.user_to_user_stream.get(user_id)
new_user_stream = self.user_to_user_stream.get(user)
if new_user_stream is not None: if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)
@ -503,13 +500,14 @@ class Notifier(object):
def wait_for_replication(self, callback, timeout): def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen. """Wait for an event to happen.
:param callback: Args:
Gets called whenever an event happens. If this returns a truthy callback: Gets called whenever an event happens. If this returns a
value then ``wait_for_replication`` returns, otherwise it waits truthy value then ``wait_for_replication`` returns, otherwise
for another event. it waits for another event.
:param int timeout: timeout: How many milliseconds to wait for callback return a truthy
How many milliseconds to wait for callback return a truthy value. value.
:returns:
Returns:
A deferred that resolves with the value returned by the callback. A deferred that resolves with the value returned by the callback.
""" """
listener = _NotificationListener(None) listener = _NotificationListener(None)

View file

@ -13,333 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
import synapse.util.async
from .push_rule_evaluator import evaluator_for_user_id
import logging
import random
logger = logging.getLogger(__name__)
_NEXT_ID = 1
def _get_next_id():
global _NEXT_ID
_id = _NEXT_ID
_NEXT_ID += 1
return _id
# Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the
# rules again.
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.user_id = user_id
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
self.badge = None
self.name = "Pusher-%d" % (_get_next_id(),)
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
with LoggingContext(self.name):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproduceable in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=0, affect_presence=False
)
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("New pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token)
else:
logger.info(
"Old pusher %s for user %s starting",
self.pushkey, self.user_id,
)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
with Measure(self.clock, "push"):
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=timeout, affect_presence=False,
only_keys=("room", "receipt",),
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
read_receipt = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
elif c['type'] == 'm.receipt':
read_receipt = c
have_updated_badge = False
if read_receipt:
for receipt_part in read_receipt['content'].values():
if 'm.read' in receipt_part:
if self.user_id in receipt_part['m.read'].keys():
have_updated_badge = True
if not single_event:
if have_updated_badge:
yield self.update_badge()
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_id,
self.last_token
)
return
if not self.alive:
return
processed = False
rule_evaluator = yield \
evaluator_for_user_id(
self.user_id, single_event['room_id'], self.store
)
actions = yield rule_evaluator.actions_for_event(single_event)
tweaks = rule_evaluator.tweaks_for_actions(actions)
if 'notify' in actions:
self.badge = yield self._get_badge_count()
rejected = yield self.dispatch_push(single_event, tweaks, self.badge)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_id
)
else:
if have_updated_badge:
yield self.update_badge()
processed = True
if not self.alive:
return
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_id,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_id, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_id,
self.last_token
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_id,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks, badge):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
@defer.inlineCallbacks
def update_badge(self):
new_badge = yield self._get_badge_count()
if self.badge != new_badge:
self.badge = new_badge
yield self.send_badge(self.badge)
def send_badge(self, badge):
"""
Overridden by implementing classes to send an updated badge count
"""
pass
@defer.inlineCallbacks
def _get_badge_count(self):
invites, joins = yield defer.gatherResults([
self.store.get_invited_rooms_for_user(self.user_id),
self.store.get_rooms_for_user(self.user_id),
], consumeErrors=True)
my_receipts_by_room = yield self.store.get_receipts_for_user(
self.user_id,
"m.read",
)
badge = len(invites)
for r in joins:
if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id]
notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_id, last_unread_event_id
)
)
badge += notifs["notify_count"]
defer.returnValue(badge)
class PusherConfigException(Exception): class PusherConfigException(Exception):
def __init__(self, msg): def __init__(self, msg):

View file

@ -15,7 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from .bulk_push_rule_evaluator import evaluator_for_room_id from .bulk_push_rule_evaluator import evaluator_for_event
from synapse.util.metrics import Measure
import logging import logging
@ -25,6 +27,7 @@ logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator:
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
@ -34,15 +37,16 @@ class ActionGenerator:
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context, handler): def handle_push_actions_for_event(self, event, context):
bulk_evaluator = yield evaluator_for_room_id( with Measure(self.clock, "handle_push_actions_for_event"):
event.room_id, self.hs, self.store bulk_evaluator = yield evaluator_for_event(
) event, self.hs, self.store, context.current_state
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state event, context.current_state
) )
context.push_actions = [ context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.items()
] ]

View file

@ -19,9 +19,11 @@ import copy
def list_with_base_rules(rawrules): def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules """Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set. Args:
:returns: A new list with the rules set by the user combined with the rawrules(list): The rules the user has modified or set.
defaults.
Returns:
A new list with the rules set by the user combined with the defaults.
""" """
ruleslist = [] ruleslist = []
@ -77,7 +79,7 @@ def make_base_append_rules(kind, modified_base_rules):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = BASE_APPEND_OVRRIDE_RULES rules = BASE_APPEND_OVERRIDE_RULES
elif kind == 'underride': elif kind == 'underride':
rules = BASE_APPEND_UNDERRIDE_RULES rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content': elif kind == 'content':
@ -146,7 +148,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
] ]
BASE_APPEND_OVRRIDE_RULES = [ BASE_APPEND_OVERRIDE_RULES = [
{ {
'rule_id': 'global/override/.m.rule.suppress_notices', 'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [ 'conditions': [
@ -160,7 +162,61 @@ BASE_APPEND_OVRRIDE_RULES = [
'actions': [ 'actions': [
'dont_notify', 'dont_notify',
] ]
} },
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
{
'rule_id': 'global/override/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
'_id': '_invite_member',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern_type': 'user_id'
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
# is that we don't though. We add this override rule so that even if
# the room rule is set to notify, we don't get notifications about
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
{
'rule_id': 'global/override/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
}
],
'actions': [
'dont_notify'
]
},
] ]
@ -229,57 +285,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
'_id': '_invite_member',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern_type': 'user_id'
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
# This is too simple: https://matrix.org/jira/browse/SYN-607
# Removing for now
# {
# 'rule_id': 'global/underride/.m.rule.member_event',
# 'conditions': [
# {
# 'kind': 'event_match',
# 'key': 'type',
# 'pattern': 'm.room.member',
# '_id': '_member',
# }
# ],
# 'actions': [
# 'notify', {
# 'set_tweak': 'highlight',
# 'value': False
# }
# ]
# },
{ {
'rule_id': 'global/underride/.m.rule.message', 'rule_id': 'global/underride/.m.rule.message',
'conditions': [ 'conditions': [
@ -312,7 +317,7 @@ for r in BASE_PREPEND_OVERRIDE_RULES:
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id']) BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_OVRRIDE_RULES: for r in BASE_APPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id']) BASE_RULE_IDS.add(r['rule_id'])

View file

@ -14,67 +14,66 @@
# limitations under the License. # limitations under the License.
import logging import logging
import ujson as json
from twisted.internet import defer from twisted.internet import defer
from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.visibility import filter_events_for_clients
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def decode_rule_json(rule):
rule['conditions'] = json.loads(rule['conditions'])
rule['actions'] = json.loads(rule['actions'])
return rule
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_rules(room_id, user_ids, store): def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = { rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
uid: list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
# fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too.
for uid in user_ids:
if uid not in rules_enabled_by_user:
continue
user_enabled_map = rules_enabled_by_user[uid]
for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id']
if rule_id in user_enabled_map:
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(user_enabled_map[rule_id])
rules_by_user[uid][i] = rule
defer.returnValue(rules_by_user) defer.returnValue(rules_by_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_room_id(room_id, hs, store): def evaluator_for_event(event, hs, store, current_state):
results = yield store.get_receipts_for_room(room_id, "m.read") room_id = event.room_id
user_ids = [ # We also will want to generate notifs for other people in the room so
row["user_id"] for row in results # their unread countss are correct in the event stream, but to avoid
if hs.is_mine_id(row["user_id"]) # generating them for bot / AS users etc, we only do so for people who've
] # sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
invited_user = event.state_key
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
user_ids.add(invited_user)
rules_by_user = yield _get_rules(room_id, user_ids, store) rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
@ -98,16 +97,24 @@ class BulkPushRuleEvaluator:
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, handler, current_state): def action_for_event_by_user(self, event, current_state):
actions_by_user = {} actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys()) # None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all
# actually in the room.
user_tuples = [
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield handler.filter_events_for_clients( filtered_by_user = yield filter_events_for_clients(
users_dict.items(), [event], {event.event_id: current_state} self.store, user_tuples, [event], {event.event_id: current_state}
) )
room_members = yield self.store.get_users_in_room(self.room_id) room_members = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View file

@ -13,29 +13,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import ( from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
) )
import copy import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map): def format_push_rules_for_user(user, ruleslist):
"""Converts a list of rawrules and a enabled map into nested dictionaries """Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules""" to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) ruleslist = copy.deepcopy(ruleslist)
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}
@ -60,9 +50,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
if r['rule_id'] in enabled_map: if 'enabled' in r:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled'] template_rule['enabled'] = r['enabled']
else: else:
template_rule['enabled'] = True template_rule['enabled'] = True

283
synapse/push/emailpusher.py Normal file
View file

@ -0,0 +1,283 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
import logging
from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext
from mailer import Mailer
logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification
# (to give the user a chance to respond to other push or notice the window)
DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
# THROTTLE is the minimum time between mail notifications sent for a given room.
# Each room maintains its own throttle counter, but each new mail notification
# sends the pending notifications for all rooms.
THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
# If no event triggers a notification for this long after the previous,
# the throttle is released.
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
# notification when things get going again...
THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
# does each email include all unread notifs, or just the ones which have happened
# since the last mail?
# XXX: this is currently broken as it includes ones from parted rooms(!)
INCLUDE_ALL_UNREAD_NOTIFS = False
class EmailPusher(object):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id']
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.email = pusherdict['pushkey']
self.last_stream_ordering = pusherdict['last_stream_ordering']
self.timed_call = None
self.throttle_params = None
# See httppusher
self.max_stream_ordering = None
self.processing = False
if self.hs.config.email_enable_notifs:
if 'data' in pusherdict and 'brand' in pusherdict['data']:
app_name = pusherdict['data']['brand']
else:
app_name = self.hs.config.email_app_name
self.mailer = Mailer(self.hs, app_name)
else:
self.mailer = None
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
self.throttle_params = yield self.store.get_throttle_params_by_room(
self.pusher_id
)
yield self._process()
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
def on_new_receipts(self, min_stream_id, max_stream_id):
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
return defer.succeed(None)
@defer.inlineCallbacks
def on_timer(self):
self.timed_call = None
yield self._process()
@defer.inlineCallbacks
def _process(self):
if self.processing:
return
with LoggingContext("emailpush._process"):
with Measure(self.clock, "emailpush._process"):
try:
self.processing = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
except:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None
for push_action in unprocessed:
received_at = push_action['received_ts']
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
room_ready_at = self.room_ready_to_notify_at(
push_action['room_id']
)
should_notify_at = max(notif_ready_at, room_ready_at)
if should_notify_at < self.clock.time_msec():
# one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications
# to be delivered.
reason = {
'room_id': push_action['room_id'],
'now': self.clock.time_msec(),
'received_at': received_at,
'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
}
yield self.send_notification(unprocessed, reason)
yield self.save_last_stream_ordering_and_success(max([
ea['stream_ordering'] for ea in unprocessed
]))
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
yield self.sent_notif_update_throttle(
ea['room_id'], ea
)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
soonest_due_at = should_notify_at
if self.timed_call is not None:
self.timed_call.cancel()
self.timed_call = None
if soonest_due_at is not None:
self.timed_call = reactor.callLater(
self.seconds_until(soonest_due_at), self.on_timer
)
@defer.inlineCallbacks
def save_last_stream_ordering_and_success(self, last_stream_ordering):
self.last_stream_ordering = last_stream_ordering
yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.email, self.user_id,
last_stream_ordering, self.clock.time_msec()
)
def seconds_until(self, ts_msec):
return (ts_msec - self.clock.time_msec()) / 1000
def get_room_throttle_ms(self, room_id):
if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"]
else:
return 0
def get_room_last_sent_ts(self, room_id):
if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"]
else:
return 0
def room_ready_to_notify_at(self, room_id):
"""
Determines whether throttling should prevent us from sending an email
for the given room
Returns: The timestamp when we are next allowed to send an email notif
for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
may_send_at = last_sent_ts + throttle_ms
return may_send_at
@defer.inlineCallbacks
def sent_notif_update_throttle(self, room_id, notified_push_action):
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
notified_push_action['stream_ordering']
)
time_of_this_notifs = notified_push_action['received_ts']
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs
else:
# if we don't know the arrival time of one of the notifs (it was not
# stored prior to email notification code) then assume a gap of
# zero which will just not reset the throttle
gap = 0
current_throttle_ms = self.get_room_throttle_ms(room_id)
if gap > THROTTLE_RESET_AFTER_MS:
new_throttle_ms = THROTTLE_START_MS
else:
if current_throttle_ms == 0:
new_throttle_ms = THROTTLE_START_MS
else:
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER,
THROTTLE_MAX_MS
)
self.throttle_params[room_id] = {
"last_sent_ts": self.clock.time_msec(),
"throttle_ms": new_throttle_ms
}
yield self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
@defer.inlineCallbacks
def send_notification(self, push_actions, reason):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
self.app_id, self.user_id, self.email, push_actions, reason
)

View file

@ -13,60 +13,239 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.push import Pusher, PusherConfigException from synapse.push import PusherConfigException
from twisted.internet import defer from twisted.internet import defer, reactor
import logging import logging
import push_rule_evaluator
import push_tools
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(object):
def __init__(self, _hs, user_id, app_id, INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
app_display_name, device_display_name, pushkey, pushkey_ts, MAX_BACKOFF_SEC = 60 * 60
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( # This one's in ms because we compare it against the clock
_hs, GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
user_id,
app_id, def __init__(self, hs, pusherdict):
app_display_name, self.hs = hs
device_display_name, self.store = self.hs.get_datastore()
pushkey, self.clock = self.hs.get_clock()
pushkey_ts, self.user_id = pusherdict['user_name']
data, self.app_id = pusherdict['app_id']
last_token, self.app_display_name = pusherdict['app_display_name']
last_success, self.device_display_name = pusherdict['device_display_name']
failing_since self.pushkey = pusherdict['pushkey']
self.pushkey_ts = pusherdict['ts']
self.data = pusherdict['data']
self.last_stream_ordering = pusherdict['last_stream_ordering']
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict['failing_since']
self.timed_call = None
self.processing = False
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None
if 'data' not in pusherdict:
raise PusherConfigException(
"No 'data' key for HTTP pusher"
)
self.data = pusherdict['data']
self.name = "%s/%s/%s" % (
pusherdict['user_name'],
pusherdict['app_id'],
pusherdict['pushkey'],
) )
if 'url' not in data:
if 'url' not in self.data:
raise PusherConfigException( raise PusherConfigException(
"'url' required in data for HTTP pusher" "'url' required in data for HTTP pusher"
) )
self.url = data['url'] self.url = self.data['url']
self.http_client = _hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.data_minus_url = {} self.data_minus_url = {}
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url['url'] del self.data_minus_url['url']
@defer.inlineCallbacks @defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge): def on_started(self):
# we probably do not want to push for every presence update yield self._process()
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
ctx = yield self.get_context_for_event(event) @defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id):
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
with LoggingContext("push.on_new_receipts"):
with Measure(self.clock, "push.on_new_receipts"):
badge = yield push_tools.get_badge_count(
self.hs.get_datastore(), self.user_id
)
yield self._send_badge(badge)
@defer.inlineCallbacks
def on_timer(self):
yield self._process()
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
@defer.inlineCallbacks
def _process(self):
if self.processing:
return
with LoggingContext("push._process"):
with Measure(self.clock, "push._process"):
try:
self.processing = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
except:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.pushkey, self.user_id,
self.last_stream_ordering,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id,
self.failing_since
)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id,
self.failing_since
)
if (
self.failing_since and
self.failing_since <
self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_id, self.pushkey)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer)
self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
break
@defer.inlineCallbacks
def _process_one(self, push_action):
if 'notify' not in push_action['actions']:
defer.returnValue(True)
tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions'])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action['event_id'], allow_none=True)
if event is None:
defer.returnValue(True) # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
if rejected is False:
defer.returnValue(False)
if isinstance(rejected, list) or isinstance(rejected, tuple):
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.remove_pusher(
self.app_id, pk, self.user_id
)
defer.returnValue(True)
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
d = { d = {
'notification': { 'notification': {
'id': event['event_id'], 'id': event.event_id, # deprecated: remove soon
'room_id': event['room_id'], 'event_id': event.event_id,
'type': event['type'], 'room_id': event.room_id,
'sender': event['user_id'], 'type': event.type,
'sender': event.user_id,
'counts': { # -- we don't mark messages as read yet so 'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing # we have no way of knowing
# Just set the badge to 1 until we have read receipts # Just set the badge to 1 until we have read receipts
@ -84,11 +263,11 @@ class HttpPusher(Pusher):
] ]
} }
} }
if event['type'] == 'm.room.member': if event.type == 'm.room.member':
d['notification']['membership'] = event['content']['membership'] d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_id d['notification']['user_is_target'] = event.state_key == self.user_id
if 'content' in event: if 'content' in event:
d['notification']['content'] = event['content'] d['notification']['content'] = event.content
if len(ctx['aliases']): if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0] d['notification']['room_alias'] = ctx['aliases'][0]
@ -115,7 +294,7 @@ class HttpPusher(Pusher):
defer.returnValue(rejected) defer.returnValue(rejected)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_badge(self, badge): def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id) logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = { d = {
'notification': { 'notification': {

514
synapse/push/mailer.py Normal file
View file

@ -0,0 +1,514 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.mail.smtp import sendmail
import email.utils
import email.mime.multipart
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events
)
from synapse.types import UserID
from synapse.api.errors import StoreError
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_client
import jinja2
import bleach
import time
import urllib
import logging
logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
"in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
MESSAGES_IN_ROOM_AND_OTHERS = \
"You have messages on %(app)s in the %(room)s room and others..."
MESSAGES_FROM_PERSON_AND_OTHERS = \
"You have messages on %(app)s from %(person)s and others..."
INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
"%(room)s room on %(app)s..."
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
CONTEXT_BEFORE = 1
CONTEXT_AFTER = 1
# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js
ALLOWED_TAGS = [
'font', # custom to matrix for IRC-style font coloring
'del', # for markdown
# deliberately no h1/h2 to stop people shouting.
'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol',
'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div',
'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre'
]
ALLOWED_ATTRS = {
# custom ones first:
"font": ["color"], # custom to matrix
"a": ["href", "name", "target"], # remote target: custom to matrix
# We don't currently allow img itself by default, but this
# would make sense if we did
"img": ["src"],
}
# When bleach release a version with this option, we can specify schemes
# ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"]
class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = self.mxc_to_http_filter
self.notif_template_html = env.get_template(
self.hs.config.email_notif_template_html
)
self.notif_template_text = env.get_template(
self.hs.config.email_notif_template_text
)
@defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address,
push_actions, reason):
try:
from_string = self.hs.config.email_notif_from % {
"app": self.app_name
}
except TypeError:
from_string = self.hs.config.email_notif_from
raw_from = email.utils.parseaddr(from_string)[1]
raw_to = email.utils.parseaddr(email_address)[1]
if raw_to == '':
raise RuntimeError("Invalid 'to' address")
rooms_in_order = deduped_ordered_list(
[pa['room_id'] for pa in push_actions]
)
notif_events = yield self.store.get_events(
[pa['event_id'] for pa in push_actions]
)
notifs_by_room = {}
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
# collect the current state for all the rooms in which we have
# notifications
state_by_room = {}
try:
user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id
@defer.inlineCallbacks
def _fetch_room_state(room_id):
room_state = yield self.state_handler.get_current_state(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
# notifs are much less realtime than sync so we can afford to wait a bit.
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(
key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
)
rooms = []
for r in rooms_in_order:
roomvars = yield self.get_room_vars(
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
)
rooms.append(roomvars)
reason['room_name'] = calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=True
)
summary_text = self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason
)
template_vars = {
"user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
"reason": reason,
}
html_text = self.notif_template_html.render(**template_vars)
html_part = MIMEText(html_text, "html", "utf8")
plain_text = self.notif_template_text.render(**template_vars)
text_part = MIMEText(plain_text, "plain", "utf8")
multipart_msg = MIMEMultipart('alternative')
multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text)
multipart_msg['From'] = from_string
multipart_msg['To'] = email_address
multipart_msg['Date'] = email.utils.formatdate()
multipart_msg['Message-ID'] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending email push notification to %s" % email_address)
# logger.debug(html_text)
yield sendmail(
self.hs.config.email_smtp_host,
raw_from, raw_to, multipart_msg.as_string(),
port=self.hs.config.email_smtp_port
)
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
my_member_event = room_state[("m.room.member", user_id)]
is_invite = my_member_event.content["membership"] == "invite"
room_vars = {
"title": calculate_room_name(room_state, user_id),
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
}
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
n, user_id, notif_events[n['event_id']], room_state
)
# merge overlapping notifs together.
# relies on the notifs being in chronological order.
merge = False
if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
prev_messages = room_vars['notifs'][-1]['messages']
for message in notifvars['messages']:
pm = filter(lambda pm: pm['id'] == message['id'], prev_messages)
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
merge = True
elif merge:
# we're merging, so append any remaining messages
# in this notif to the previous one
prev_messages.append(message)
if not merge:
room_vars['notifs'].append(notifvars)
defer.returnValue(room_vars)
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state):
results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
)
ret = {
"link": self.make_notif_link(notif),
"ts": notif['received_ts'],
"messages": [],
}
the_events = yield filter_events_for_client(
self.store, user_id, results["events_before"]
)
the_events.append(notif_event)
for event in the_events:
messagevars = self.get_message_vars(notif, event, room_state)
if messagevars is not None:
ret['messages'].append(messagevars)
defer.returnValue(ret)
def get_message_vars(self, notif, event, room_state):
if event.type != EventTypes.Message:
return None
sender_state_event = room_state[("m.room.member", event.sender)]
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = None
if "avatar_url" in sender_state_event.content:
sender_avatar_url = sender_state_event.content["avatar_url"]
# 'hash' for deterministically picking default images: use
# sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender)
ret = {
"msgtype": event.content["msgtype"],
"is_historical": event.event_id != notif['event_id'],
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
"sender_avatar_url": sender_avatar_url,
"sender_hash": sender_hash,
}
if event.content["msgtype"] == "m.text":
self.add_text_message_vars(ret, event)
elif event.content["msgtype"] == "m.image":
self.add_image_message_vars(ret, event)
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
return ret
def add_text_message_vars(self, messagevars, event):
if "format" in event.content:
msgformat = event.content["format"]
else:
msgformat = None
messagevars["format"] = msgformat
if msgformat == "org.matrix.custom.html":
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
else:
messagevars["body_text_html"] = safe_text(event.content["body"])
return messagevars
def add_image_message_vars(self, messagevars, event):
messagevars["image_url"] = event.content["url"]
return messagevars
def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason):
if len(notifs_by_room) == 1:
# Only one room has new stuff
room_id = notifs_by_room.keys()[0]
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
room_name = calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False
)
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
if my_member_event.content["membership"] == "invite":
inviter_member_event = state_by_room[room_id][
("m.room.member", my_member_event.sender)
]
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
return INVITE_FROM_PERSON % {
"person": inviter_name,
"app": self.app_name
}
else:
return INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name,
"room": room_name,
"app": self.app_name,
}
sender_name = None
if len(notifs_by_room[room_id]) == 1:
# There is just the one notification, so give some detail
event = notif_events[notifs_by_room[room_id][0]["event_id"]]
if ("m.room.member", event.sender) in state_by_room[room_id]:
state_event = state_by_room[room_id][("m.room.member", event.sender)]
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
return MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name,
"room": room_name,
"app": self.app_name,
}
elif sender_name is not None:
return MESSAGE_FROM_PERSON % {
"person": sender_name,
"app": self.app_name,
}
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
return MESSAGES_IN_ROOM % {
"room": room_name,
"app": self.app_name,
}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(set([
notif_events[n['event_id']].sender
for n in notifs_by_room[room_id]
]))
return MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([
state_by_room[room_id][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
}
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
if reason['room_name'] is not None:
return MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'],
"app": self.app_name,
}
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(set([
notif_events[n['event_id']].sender
for n in notifs_by_room[reason['room_id']]
]))
return MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
}
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s" % (room_id,)
else:
return "https://matrix.to/#/%s" % (room_id,)
def make_notif_link(self, notif):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id']
)
else:
return "https://matrix.to/#/%s/%s" % (
notif['room_id'], notif['event_id']
)
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
urllib.urlencode(params),
)
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if '#' in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
fragment = "#" + fragment
params = {
"width": width,
"height": height,
"method": resize_method,
}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
self.hs.config.public_baseurl,
serverAndMediaId,
urllib.urlencode(params),
fragment or "",
)
def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean(
raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS,
# bleach master has this, but it isn't released yet
# protocols=ALLOWED_SCHEMES,
strip=True
)))
def safe_text(raw_text):
"""
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
"""
return jinja2.Markup(bleach.linkify(bleach.clean(
raw_text, tags=[], attributes={},
strip=False
)))
def deduped_ordered_list(l):
seen = set()
ret = []
for item in l:
if item not in seen:
seen.add(item)
ret.append(item)
return ret
def string_ordinal_total(s):
tot = 0
for c in s:
tot += ord(c)
return tot
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))

View file

@ -13,12 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from .baserules import list_with_base_rules
import logging import logging
import simplejson as json
import re import re
from synapse.types import UserID from synapse.types import UserID
@ -32,22 +27,6 @@ IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks
def evaluator_for_user_id(user_id, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=user_id,
)
defer.returnValue(PushRuleEvaluator(
user_id, rawrules, enabled_map,
room_id, our_member_event, store
))
def _room_member_count(ev, condition, room_member_count): def _room_member_count(ev, condition, room_member_count):
if 'is' not in condition: if 'is' not in condition:
return False return False
@ -74,110 +53,14 @@ def _room_member_count(ev, condition, room_member_count):
return False return False
class PushRuleEvaluator: def tweaks_for_actions(actions):
DEFAULT_ACTIONS = [] tweaks = {}
for a in actions:
def __init__(self, user_id, raw_rules, enabled_map, room_id, if not isinstance(a, dict):
our_member_event, store): continue
self.user_id = user_id if 'set_tweak' in a and 'value' in a:
self.room_id = room_id tweaks[a['set_tweak']] = a['value']
self.our_member_event = our_member_event return tweaks
self.store = store
rules = []
for raw_rule in raw_rules:
rule = dict(raw_rule)
rule['conditions'] = json.loads(raw_rule['conditions'])
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
self.rules = list_with_base_rules(rules)
self.enabled_map = enabled_map
@staticmethod
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
@defer.inlineCallbacks
def actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_id:
# let's assume you probably know about messages you sent yourself
defer.returnValue([])
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
if self.our_member_event:
my_display_name = self.our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules:
enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled:
continue
if not r.get("enabled", True):
continue
conditions = r['conditions']
actions = r['actions']
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_id
)
continue
matches = True
for c in conditions:
matches = evaluator.matches(
c, self.user_id, my_display_name
)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches:
logger.debug(
"%s matches for user %s, event %s",
r['rule_id'], self.user_id, ev['event_id']
)
# filter out dont_notify as we treat an empty actions list
# as dont_notify, and this doesn't take up a row in our database
actions = [x for x in actions if x != 'dont_notify']
defer.returnValue(actions)
logger.debug(
"No rules match for user %s, event %s",
self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
class PushRuleEvaluatorForEvent(object): class PushRuleEvaluatorForEvent(object):

View file

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield defer.gatherResults([
store.get_invited_rooms_for_user(user_id),
store.get_rooms_for_user(user_id),
], consumeErrors=True)
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
)
badge = len(invites)
for r in joins:
if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id]
notifs = yield (
store.get_unread_event_push_actions_by_room_for_user(
r.room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
defer.returnValue(badge)
@defer.inlineCallbacks
def get_context_for_event(store, ev):
name_aliases = yield store.get_room_name_and_aliases(
ev.room_id
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield store.get_current_state(
room_id=ev.room_id,
event_type='m.room.member',
state_key=ev.user_id
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)

47
synapse/push/pusher.py Normal file
View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
# have the optional email dependencies installed). We don't
# yet have the config to know if we need the email pusher,
# but importing this after daemonizing seems to fail
# (even though a simple test of importing from a daemonized
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
except:
pass
def create_pusher(hs, pusherdict):
logger.info("trying to create_pusher for %r", pusherdict)
PUSHER_TYPES = {
"http": HttpPusher,
}
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
PUSHER_TYPES["email"] = EmailPusher
logger.info("defined email pusher type")
if pusherdict['kind'] in PUSHER_TYPES:
logger.info("found pusher")
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)

View file

@ -16,9 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from .httppusher import HttpPusher import pusher
from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.util.async import run_on_reactor
import logging import logging
@ -28,10 +28,10 @@ logger = logging.getLogger(__name__)
class PusherPool: class PusherPool:
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@ -48,7 +48,8 @@ class PusherPool:
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
# code path adding pushers. # code path adding pushers.
self._create_pusher({ pusher.create_pusher(self.hs, {
"id": None,
"user_name": user_id, "user_name": user_id,
"kind": kind, "kind": kind,
"app_id": app_id, "app_id": app_id,
@ -58,10 +59,18 @@ class PusherPool:
"ts": time_now_msec, "ts": time_now_msec,
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_stream_ordering": None,
"last_success": None, "last_success": None,
"failing_since": None "failing_since": None
}) })
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = (
yield self.store.get_latest_push_action_stream_ordering()
)
yield self.store.add_pusher( yield self.store.add_pusher(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
@ -73,6 +82,7 @@ class PusherPool:
pushkey_ts=time_now_msec, pushkey_ts=time_now_msec,
lang=lang, lang=lang,
data=data, data=data,
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
) )
yield self._refresh_pusher(app_id, pushkey, user_id) yield self._refresh_pusher(app_id, pushkey, user_id)
@ -106,26 +116,51 @@ class PusherPool:
) )
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
def _create_pusher(self, pusherdict): @defer.inlineCallbacks
if pusherdict['kind'] == 'http': def on_new_notifications(self, min_stream_id, max_stream_id):
return HttpPusher( yield run_on_reactor()
self.hs, try:
user_id=pusherdict['user_name'], users_affected = yield self.store.get_push_action_users_in_range(
app_id=pusherdict['app_id'], min_stream_id, max_stream_id
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
) )
else:
raise PusherConfigException( deferreds = []
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name']) for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_notifications(min_stream_id, max_stream_id)
)
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
yield run_on_reactor()
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
updated_receipts = yield self.store.get_all_updated_receipts(
min_stream_id - 1, max_stream_id
) )
# This returns a tuple, user_id is at index 3
users_affected = set([r[3] for r in updated_receipts])
deferreds = []
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_receipts(min_stream_id, max_stream_id)
)
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id, pushkey, user_id): def _refresh_pusher(self, app_id, pushkey, user_id):
@ -143,33 +178,40 @@ class PusherPool:
self._start_pushers([p]) self._start_pushers([p])
def _start_pushers(self, pushers): def _start_pushers(self, pushers):
if not self.start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
try: try:
p = self._create_pusher(pusherdict) p = pusher.create_pusher(self.hs, pusherdict)
except PusherConfigException: except:
logger.exception("Couldn't start a pusher: caught PusherConfigException") logger.exception("Couldn't start a pusher: caught Exception")
continue continue
if p: if p:
fullid = "%s:%s:%s" % ( appid_pushkey = "%s:%s" % (
pusherdict['app_id'], pusherdict['app_id'],
pusherdict['pushkey'], pusherdict['pushkey'],
pusherdict['user_name']
) )
if fullid in self.pushers: byuser = self.pushers.setdefault(pusherdict['user_name'], {})
self.pushers[fullid].stop()
self.pushers[fullid] = p if appid_pushkey in byuser:
preserve_fn(p.start)() byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
preserve_fn(p.on_started)()
logger.info("Started pushers") logger.info("Started pushers")
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id): def remove_pusher(self, app_id, pushkey, user_id):
fullid = "%s:%s:%s" % (app_id, pushkey, user_id) appid_pushkey = "%s:%s" % (app_id, pushkey)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid) byuser = self.pushers.get(user_id, {})
self.pushers[fullid].stop()
del self.pushers[fullid] if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop()
del byuser[appid_pushkey]
yield self.store.delete_pusher_by_app_id_pushkey_user_id( yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id app_id, pushkey, user_id
) )

View file

@ -40,7 +40,14 @@ REQUIREMENTS = {
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
} },
"preview_url": {
"netaddr>=0.7.18": ["netaddr"],
},
"email.enable_notifs": {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
},
} }

View file

@ -0,0 +1,59 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PresenceResource(Resource):
"""
HTTP endpoint for marking users as syncing.
POST /_synapse/replication/presence HTTP/1.1
Content-Type: application/json
{
"process_id": "<process_id>",
"syncing_users": ["<user_id>"]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
process_id = content["process_id"]
syncing_user_ids = content["syncing_users"]
yield self.presence_handler.update_external_syncs(
process_id, set(syncing_user_ids)
)
respond_with_json_bytes(request, 200, "{}")

View file

@ -0,0 +1,54 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PusherResource(Resource):
"""
HTTP endpoint for deleting rejected pushers
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
for remove in content["remove"]:
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
remove["app_id"],
remove["push_key"],
remove["user_id"],
)
self.notifier.on_new_replication_data()
respond_with_json_bytes(request, 200, "{}")

View file

@ -15,6 +15,8 @@
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -38,6 +40,7 @@ STREAM_NAMES = (
("backfill",), ("backfill",),
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",),
) )
@ -76,7 +79,7 @@ class ReplicationResource(Resource):
The response is a JSON object with keys for each stream with updates. Under The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with: each key is a JSON object with:
* "postion": The current position of the stream. * "position": The current position of the stream.
* "field_names": The names of the fields in each row. * "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays. * "rows": The updates as an array of arrays.
@ -101,17 +104,19 @@ class ReplicationResource(Resource):
long-polling this replication API for new data on those streams. long-polling this replication API for new data on those streams.
""" """
isLeaf = True
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super() Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string self.version_string = hs.version_string
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.sources = hs.get_event_sources() self.sources = hs.get_event_sources()
self.presence_handler = hs.get_handlers().presence_handler self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_handlers().typing_notification_handler self.typing_handler = hs.get_typing_handler()
self.notifier = hs.notifier self.notifier = hs.notifier
self.clock = hs.get_clock()
self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
@ -123,6 +128,7 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -133,40 +139,62 @@ class ReplicationResource(Resource):
backfill_token, backfill_token,
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token,
)) ))
@request_handler @request_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
limit = parse_integer(request, "limit", 100) limit = parse_integer(request, "limit", 100)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks request_streams = {
name: parse_integer(request, name)
for names in STREAM_NAMES for name in names
}
request_streams["streams"] = parse_string(request, "streams")
def replicate(): def replicate():
current_token = yield self.current_replication_token() return self.replicate(request_streams, limit)
logger.info("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit) result = yield self.notifier.wait_for_replication(replicate, timeout)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total) for stream_name, stream_content in result.items():
defer.returnValue(writer.total) logger.info(
"Replicating %d rows of %s from %s -> %s",
len(stream_content["rows"]),
stream_name,
request_streams.get(stream_name),
stream_content["position"],
)
yield self.notifier.wait_for_replication(replicate, timeout) request.write(json.dumps(result, ensure_ascii=False))
finish_request(request)
writer.finish() @defer.inlineCallbacks
def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
def streams(self, writer, current_token): yield self.account_data(writer, current_token, limit, request_streams)
request_token = parse_string(writer.request, "streams") yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit
yield self.presence(writer, current_token, request_streams)
yield self.typing(writer, current_token, request_streams)
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams")
streams = [] streams = []
@ -191,32 +219,43 @@ class ReplicationResource(Resource):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def events(self, writer, current_token, limit): def events(self, writer, current_token, limit, request_streams):
request_events = parse_integer(writer.request, "events") request_events = request_streams.get("events")
request_backfill = parse_integer(writer.request, "backfill") request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None: if request_events is not None or request_backfill is not None:
if request_events is None: if request_events is None:
request_events = current_token.events request_events = current_token.events
if request_backfill is None: if request_backfill is None:
request_backfill = current_token.backfill request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events( res = yield self.store.get_all_new_events(
request_backfill, request_events, request_backfill, request_events,
current_token.backfill, current_token.events, current_token.backfill, current_token.events,
limit limit
) )
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows( writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json") "forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group")
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json") "backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def presence(self, writer, current_token): def presence(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence") request_presence = request_streams.get("presence")
if request_presence is not None: if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates( presence_rows = yield self.presence_handler.get_all_presence_updates(
@ -229,10 +268,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token): def typing(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
@ -243,10 +282,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def receipts(self, writer, current_token, limit): def receipts(self, writer, current_token, limit, request_streams):
current_position = current_token.receipts current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts") request_receipts = request_streams.get("receipts")
if request_receipts is not None: if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts( receipts_rows = yield self.store.get_all_updated_receipts(
@ -257,12 +296,12 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def account_data(self, writer, current_token, limit): def account_data(self, writer, current_token, limit, request_streams):
current_position = current_token.account_data current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data") user_account_data = request_streams.get("user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data") room_account_data = request_streams.get("room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data") tag_account_data = request_streams.get("tag_account_data")
if user_account_data is not None or room_account_data is not None: if user_account_data is not None or room_account_data is not None:
if user_account_data is None: if user_account_data is None:
@ -288,10 +327,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def push_rules(self, writer, current_token, limit): def push_rules(self, writer, current_token, limit, request_streams):
current_position = current_token.push_rules current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules") push_rules = request_streams.get("push_rules")
if push_rules is not None: if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates( rows = yield self.store.get_all_push_rule_updates(
@ -303,10 +342,11 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def pushers(self, writer, current_token, limit): def pushers(self, writer, current_token, limit, request_streams):
current_position = current_token.pushers current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers") pushers = request_streams.get("pushers")
if pushers is not None: if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers( updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit pushers, current_position, limit
@ -316,16 +356,34 @@ class ReplicationResource(Resource):
"app_id", "app_display_name", "device_display_name", "pushkey", "app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data" "ts", "lang", "data"
)) ))
writer.write_header_and_rows("deleted", deleted, ( writer.write_header_and_rows("deleted_pushers", deleted, (
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
def __init__(self, request): def __init__(self):
self.streams = {} self.streams = {}
self.request = request
self.total = 0 self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None): def write_header_and_rows(self, name, rows, fields, position=None):
@ -336,7 +394,7 @@ class _Writer(object):
position = rows[-1][0] position = rows[-1][0]
self.streams[name] = { self.streams[name] = {
"position": str(position), "position": position if type(position) is int else str(position),
"field_names": fields, "field_names": fields,
"rows": rows, "rows": rows,
} }
@ -344,13 +402,12 @@ class _Writer(object):
self.total += len(rows) self.total += len(rows)
def finish(self): def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False)) return self.streams
finish_request(self.request)
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers" "push_rules", "pushers", "state"
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from twisted.internet import defer
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
def stream_positions(self):
return {}
def process_replication(self, result):
return defer.succeed(None)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker(object):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
def advance(self, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
return self._current

View file

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication(self, result):
stream = result.get("user_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__

View file

@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
_get_state_group_for_event = (
StateStore.__dict__["_get_state_group_for_event"]
)
_get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"]
)
_get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"]
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
)
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
)
stream = result.get("forward_ex_outliers")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
)
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
(event.room_id,)
)
if not backfilled:
self._events_stream_cache.entity_has_changed(
event.room_id, event.internal_metadata.stream_ordering
)
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
# (event.room_id,)
# )
if event.type == EventTypes.Redaction:
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():
return
if backfilled:
return
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
if event.type in [EventTypes.Name, EventTypes.Aliases]:
self.get_room_name_and_aliases.invalidate(
(event.room_id,)
)
pass

View file

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]

View file

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
self._presence_id_gen = SlavedIdTracker(
db_conn, "presence_stream", "stream_id",
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedPresenceStore, self).process_replication(result)

View file

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)

View file

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
class SlavedPusherStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
get_all_pushers = DataStore.get_all_pushers.__func__
get_pushers_by = DataStore.get_pushers_by.__func__
get_pushers_by_app_id_and_pushkey = (
DataStore.get_pushers_by_app_id_and_pushkey.__func__
)
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
stream = result.get("deleted_pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
return super(SlavedPusherStore, self).process_replication(result)

View file

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)

Some files were not shown because too many files have changed in this diff Show more