Merge branch 'develop' into key_distribution

Conflicts:
	synapse/config/homeserver.py
This commit is contained in:
Mark Haines 2015-04-29 13:15:14 +01:00
commit 4ad8b45155
115 changed files with 4445 additions and 1137 deletions

31
CAPTCHA_SETUP Normal file
View file

@ -0,0 +1,31 @@
Captcha can be enabled for this home server. This file explains how to do that.
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
Getting keys
------------
Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Setting ReCaptcha Keys
----------------------
The keys are a config option on the home server config. If they are not
visible, you can generate them via --generate-config. Set the following value:
recaptcha_public_key: YOUR_PUBLIC_KEY
recaptcha_private_key: YOUR_PRIVATE_KEY
In addition, you MUST enable captchas via:
enable_registration_captcha: true
Configuring IP used for auth
----------------------------
The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer,
it may be required to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:
captcha_ip_origin_is_x_forwarded: true

View file

@ -20,7 +20,7 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be ``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by the web client at http://matrix.org/alpha or via an IRC bridge at accessed by the web client at http://matrix.org/beta or via an IRC bridge at
irc://irc.freenode.net/matrix. irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it Synapse is currently in rapid development, but as of version 0.5 we believe it
@ -69,21 +69,27 @@ Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
web client demo implemented in AngularJS) and cmdclient (a basic Python web client demo implemented in AngularJS) and cmdclient (a basic Python
command line utility which lets you easily see what the JSON APIs are up to). command line utility which lets you easily see what the JSON APIs are up to).
Meanwhile, iOS and Android SDKs and clients are currently in development and available from: Meanwhile, iOS and Android SDKs and clients are available from:
- https://github.com/matrix-org/matrix-ios-sdk - https://github.com/matrix-org/matrix-ios-sdk
- https://github.com/matrix-org/matrix-ios-kit
- https://github.com/matrix-org/matrix-ios-console
- https://github.com/matrix-org/matrix-android-sdk - https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at We'd like to invite you to join #matrix:matrix.org (via
http://matrix.org/docs/spec, experiment with the APIs and the demo https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at
clients, and report any bugs via http://matrix.org/jira. https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api,
experiment with the APIs and the demo clients, and report any bugs via
https://matrix.org/jira.
Thanks for using Matrix! Thanks for using Matrix!
[1] End-to-end encryption is currently in development [1] End-to-end encryption is currently in development
Homeserver Installation Synapse Installation
======================= ====================
Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OS X) - POSIX-compliant system (tested on Linux & OS X)
@ -118,6 +124,9 @@ To install the synapse homeserver run::
This installs synapse, along with the libraries it uses, into a virtual This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. environment under ``~/.synapse``.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse $ cd ~/.synapse
@ -152,36 +161,51 @@ you can use the command line to register new users::
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.
Troubleshooting Installation Using PostgreSQL
---------------------------- ================
Synapse requires pip 1.7 or later, so if your OS provides too old a version and As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an
you get errors about ``error: no such option: --process-dependency-links`` you alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has
may need to manually upgrade it:: traditionally used for convenience and simplicity.
$ sudo pip install --upgrade pip The advantages of Postgres include:
If pip crashes mid-installation for reason (e.g. lost terminal), pip may * significant performance improvements due to the superior threading and
refuse to run until you remove the temporary installation directory it caching model, smarter query optimiser
created. To reset the installation:: * allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
$ rm -rf /tmp/pip_install_matrix The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
pip seems to leak *lots* of memory during installation. For instance, a Linux For information on how to install and use PostgreSQL, please see
host with 512MB of RAM may run out of memory whilst installing Twisted. If this `docs/postgres.rst <docs/postgres.rst>`_.
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted Running Synapse
===============
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you To actually run your new homeserver, pick a working directory for Synapse to run
will need to export CFLAGS=-Qunused-arguments. (e.g. ``~/.synapse``), and::
$ cd ~/.synapse
$ source ./bin/activate
$ synctl start
Platform Specific Instructions
==============================
ArchLinux ArchLinux
--------- ---------
Installation on ArchLinux may encounter a few hiccups as Arch defaults to The quickest way to get up and running with ArchLinux is probably with Ivan
python 3, but synapse currently assumes python 2.7 by default. Shapovalov's AUR package from
https://aur.archlinux.org/packages/matrix-synapse/, which should pull in all
the necessary dependencies.
Alternatively, to install using pip a few changes may be needed as ArchLinux
defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):: pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
@ -201,7 +225,7 @@ installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt $ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt $ sudo pip2.7 install py-bcrypt
During setup of homeserver you need to call python2.7 directly again:: During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse $ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ $ python2.7 -m synapse.app.homeserver \
@ -242,15 +266,33 @@ Troubleshooting:
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a`` it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
Running Your Homeserver Troubleshooting
======================= ===============
To actually run your new homeserver, pick a working directory for Synapse to run Troubleshooting Installation
(e.g. ``~/.synapse``), and:: ----------------------------
$ cd ~/.synapse Synapse requires pip 1.7 or later, so if your OS provides too old a version and
$ source ./bin/activate you get errors about ``error: no such option: --process-dependency-links`` you
$ synctl start may need to manually upgrade it::
$ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running Troubleshooting Running
----------------------- -----------------------
@ -271,7 +313,7 @@ fix try re-installing from PyPI or directly from
$ pip install --user https://github.com/pyca/pynacl/tarball/master $ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux ArchLinux
--------- ~~~~~~~~~
If running `$ synctl start` fails with 'returned non-zero exit status 1', If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as:: you will need to explicitly call Python2.7 - either running as::
@ -280,16 +322,16 @@ you will need to explicitly call Python2.7 - either running as::
...or by editing synctl with the correct python executable. ...or by editing synctl with the correct python executable.
Homeserver Development Synapse Development
====================== ===================
To check out a homeserver for development, clone the git repo into a working To check out a synapse for development, clone the git repo into a working
directory of your choice:: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git $ git clone https://github.com/matrix-org/synapse.git
$ cd synapse $ cd synapse
The homeserver has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
$ virtualenv env $ virtualenv env
@ -300,7 +342,7 @@ to install using pip and a virtualenv::
This will run a process of downloading and installing all the needed This will run a process of downloading and installing all the needed
dependencies into a virtual env. dependencies into a virtual env.
Once this is done, you may wish to run the homeserver's unit tests, to Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
$ python setup.py test $ python setup.py test
@ -312,10 +354,10 @@ This should end with a 'PASSED' result::
PASSED (successes=143) PASSED (successes=143)
Upgrading an existing homeserver Upgrading an existing Synapse
================================ =============================
IMPORTANT: Before upgrading an existing homeserver to a new version, please IMPORTANT: Before upgrading an existing synapse to a new version, please
refer to UPGRADE.rst for any additional instructions. refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g. Otherwise, simply re-install the new codebase over the current one - e.g.
@ -376,8 +418,8 @@ SRV record, as that is the name other machines will expect it to have::
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
increase the verbosity of logging output; at least for initial testing. increase the verbosity of logging output; at least for initial testing.
Running a Demo Federation of Homeservers Running a Demo Federation of Synapses
---------------------------------------- -------------------------------------
If you want to get up and running quickly with a trio of homeservers in a If you want to get up and running quickly with a trio of homeservers in a
private federation (``localhost:8080``, ``localhost:8081`` and private federation (``localhost:8080``, ``localhost:8081`` and
@ -412,7 +454,10 @@ account. Your name will take the form of::
Specify your desired localpart in the topmost box of the "Register for an Specify your desired localpart in the topmost box of the "Register for an
account" form, and click the "Register" button. Hostnames can contain ports if account" form, and click the "Register" button. Hostnames can contain ports if
required due to lack of SRV records (e.g. @matthew:localhost:8448 on an required due to lack of SRV records (e.g. @matthew:localhost:8448 on an
internal synapse sandbox running on localhost) internal synapse sandbox running on localhost).
If registration fails, you may need to enable it in the homeserver (see
`Synapse Installation`_ above)
Logging In To An Existing Account Logging In To An Existing Account

114
docs/postgres.rst Normal file
View file

@ -0,0 +1,114 @@
Using Postgres
--------------
Set up database
===============
The PostgreSQL database used *must* have the correct encoding set, otherwise
would not be able to store UTF8 strings. To create a database with the correct
encoding use, e.g.::
CREATE DATABASE synapse
ENCODING 'UTF8'
LC_COLLATE='C'
LC_CTYPE='C'
template=template0
OWNER synapse_user;
This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist).
Set up client
=============
Postgres support depends on the postgres python connector ``psycopg2``. In the
virtual env::
sudo apt-get install libpq-dev
pip install psycopg2
Synapse config
==============
When you are ready to start using PostgreSQL, add the following line to your
config file::
database_config: <db_config_file>
Where ``<db_config_file>`` is the file name that points to a yaml file of the
following form::
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted
adbapi connection pool.
Porting from SQLite
===================
Overview
~~~~~~~~
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing
synapse server backed by SQLite to using PostgreSQL. This is done in as a two
phase process:
1. Copy the existing SQLite database to a separate location (while the server
is down) and running the port script against that offline database.
2. Shut down the server. Rerun the port script to port any data that has come
in since taking the first snapshot. Restart server against the PostgreSQL
database.
The port script is designed to be run repeatedly against newer snapshots of the
SQLite database file. This makes it safe to repeat step 1 if there was a delay
between taking the previous snapshot and being ready to do step 2.
It is safe to at any time kill the port script and restart it.
Using the port script
~~~~~~~~~~~~~~~~~~~~~
Firstly, shut down the currently running synapse server and copy its database
file (typically ``homeserver.db``) to another location. Once the copy is
complete, restart synapse. For instance::
./synctl stop
cp homeserver.db homeserver.db.snapshot
./synctl start
Assuming your database config file (as described in the section *Synapse
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db.snapshot \
--postgres-config database_config.yaml
The flag ``--curses`` displays a coloured curses progress UI.
If the script took a long time to complete, or time has otherwise passed since
the original snapshot was taken, repeat the previous steps with a newer
snapshot.
To complete the conversion shut down the synapse server and run the port
script one last time, e.g. if the SQLite database is at ``homeserver.db``
run::
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db \
--postgres-config database_config.yaml
Once that has completed, change the synapse config to point at the PostgreSQL
database configuration file using the ``database_config`` parameter (see
`Synapse Config`_) and restart synapse. Synapse should now be running against
PostgreSQL.

View file

@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest() ).hexdigest()
data = { data = {
"user": user, "username": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,), "%s/_matrix/client/v2_alpha/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View file

@ -0,0 +1,754 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
import argparse
import curses
import logging
import sys
import time
import traceback
import yaml
logger = logging.getLogger("port_from_sqlite_to_postgres")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
}
APPEND_ONLY_TABLES = [
"event_content_hashes",
"event_reference_hashes",
"event_signatures",
"event_edge_hashes",
"events",
"event_json",
"state_events",
"room_memberships",
"feedback",
"topics",
"room_names",
"rooms",
"local_media_repository",
"local_media_repository_thumbnails",
"remote_media_cache",
"remote_media_cache_thumbnails",
"redactions",
"event_edges",
"event_auth",
"received_transactions",
"sent_transactions",
"transaction_id_to_pdu",
"users",
"state_groups",
"state_groups_state",
"event_to_state_groups",
"rejections",
]
end_error_exec_info = None
class Store(object):
"""This object is used to pull out some of the convenience API from the
Storage layer.
*All* database interactions should go through this object.
"""
def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
_simple_insert = SQLBaseStore.__dict__["_simple_insert"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
def runInteraction(self, desc, func, *args, **kwargs):
def r(conn):
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N:
i += 1
conn.rollback()
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", desc, e)
raise
return self.db_pool.runWithConnection(r)
def execute(self, f, *args, **kwargs):
return self.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
", ".join("%s" for _ in headers)
)
try:
txn.executemany(sql, rows)
except:
logger.exception(
"Failed to insert: %s",
table,
)
raise
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@defer.inlineCallbacks
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
next_chunk = yield self.postgres_store._simple_select_one_onecol(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcol="rowid",
allow_none=True,
)
total_to_port = None
if next_chunk is None:
if table == "sent_transactions":
next_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
else:
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 1}
)
next_chunk = 1
already_ported = 0
if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
else:
def delete_all(txn):
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
(table,)
)
txn.execute("TRUNCATE %s CASCADE" % (table,))
yield self.postgres_store.execute(delete_all)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 0}
)
next_chunk = 1
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
defer.returnValue((table, already_ported, total_to_port, next_chunk))
@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, next_chunk):
if not table_size:
return
self.progress.add_table(table, postgres_size, table_size)
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
)
while True:
def r(txn):
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
next_chunk = rows[-1][0] + 1
self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
self.progress.update(table, postgres_size)
else:
return
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
k: v for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
db_conn.commit()
@defer.inlineCallbacks
def run(self):
try:
sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"],
**self.sqlite_config["args"]
)
postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"],
**self.postgres_config["args"]
)
sqlite_engine = create_engine("sqlite3")
postgres_engine = create_engine("psycopg2")
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
# Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3")
self.setup_db(sqlite_config, sqlite_engine)
self.progress.set_state("Preparing PostgreSQL")
self.setup_db(postgres_config, postgres_engine)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
)
postgres_tables = yield self.postgres_store._simple_select_onecol(
table="information_schema.tables",
keyvalues={
"table_schema": "public",
},
retcol="distinct table_name",
)
tables = set(sqlite_tables) & set(postgres_tables)
self.progress.set_state("Creating tables")
logger.info("Found %d tables", len(tables))
def create_port_table(txn):
txn.execute(
"CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
" rowid bigint NOT NULL"
")"
)
try:
yield self.postgres_store.runInteraction(
"create_port_table", create_port_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
self.progress.set_state("Setting up")
# Set up tables.
setup_res = yield defer.gatherResults(
[
self.setup_table(table)
for table in tables
if table not in ["schema_version", "applied_schema_deltas"]
and not table.startswith("sqlite_")
],
consumeErrors=True,
)
# Process tables.
yield defer.gatherResults(
[
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
)
self.progress.done()
except:
global end_error_exec_info
end_error_exec_info = sys.exc_info()
logger.exception("")
finally:
reactor.stop()
def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [
i for i, h in enumerate(headers) if h in bool_col_names
]
def conv(j, col):
if j in bool_cols:
return bool(col)
return col
for i, row in enumerate(rows):
rows[i] = tuple(
self.postgres_store.database_engine.encode_parameter(
conv(j, col)
)
for j, col in enumerate(row)
if j > 0
)
@defer.inlineCallbacks
def _setup_sent_transactions(self):
# Only save things from the last day
yesterday = int(time.time()*1000) - 86400000
# And save the max transaction id from each destination
select = (
"SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
"SELECT max(rowid) FROM sent_transactions"
" GROUP BY destination"
")"
)
def r(txn):
txn.execute(select)
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
ts_ind = headers.index('ts')
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction(
"select", r,
)
self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows)
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
yield self.postgres_store.execute(insert)
def get_start_id(txn):
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
(yesterday,)
)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
return 1
next_chunk = yield self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": "sent_transactions", "rowid": next_chunk}
)
def get_sent_table_size(txn):
txn.execute(
"SELECT count(*) FROM sent_transactions"
" WHERE ts >= ?",
(yesterday,)
)
size, = txn.fetchone()
return int(size)
remaining_count = yield self.sqlite_store.execute(
get_sent_table_size
)
total_count = remaining_count + inserted_rows
defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, next_chunk):
rows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
next_chunk,
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,),
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_total_count_to_port(self, table, next_chunk):
remaining, done = yield defer.gatherResults(
[
self._get_remaining_count_to_port(table, next_chunk),
self._get_already_ported_count(table),
],
consumeErrors=True,
)
remaining = int(remaining) if remaining else 0
done = int(done) if done else 0
defer.returnValue((done, remaining + done))
##############################################
###### The following is simply UI stuff ######
##############################################
class Progress(object):
"""Used to report progress of the port
"""
def __init__(self):
self.tables = {}
self.start_time = int(time.time())
def add_table(self, table, cur, size):
self.tables[table] = {
"start": cur,
"num_done": cur,
"total": size,
"perc": int(cur * 100 / size),
}
def update(self, table, num_done):
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
def done(self):
pass
class CursesProgress(Progress):
"""Reports progress to a curses window
"""
def __init__(self, stdscr):
self.stdscr = stdscr
curses.use_default_colors()
curses.curs_set(0)
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.finished = False
self.total_processed = 0
self.total_remaining = 0
super(CursesProgress, self).__init__()
def update(self, table, num_done):
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
for table, data in self.tables.items():
self.total_processed += data["num_done"] - data["start"]
self.total_remaining += data["total"] - data["num_done"]
self.render()
def render(self, force=False):
now = time.time()
if not force and now - self.last_update < 0.2:
# reactor.callLater(1, self.render)
return
self.stdscr.clear()
rows, cols = self.stdscr.getmaxyx()
duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,)
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
else:
if self.total_processed > 0:
left = float(self.total_remaining) / self.total_processed
est_remaining = (int(now) - self.start_time) * left
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = (
"Time spent: %s (est. remaining: %s)"
% (duration_str, est_remaining_str,)
)
self.stdscr.addstr(
0, 0,
status,
curses.A_BOLD,
)
max_len = max([len(t) for t in self.tables.keys()])
left_margin = 5
middle_space = 1
items = self.tables.items()
items.sort(
key=lambda i: (i[1]["perc"], i[0]),
)
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
break
perc = data["perc"]
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
i+2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
)
size = 20
progress = "[%s%s]" % (
"#" * int(perc*size/100),
" " * (size - int(perc*size/100)),
)
self.stdscr.addstr(
i+2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
rows-1, 0,
"Press any key to exit...",
)
self.stdscr.refresh()
self.last_update = time.time()
def done(self):
self.finished = True
self.render(True)
self.stdscr.getch()
def set_state(self, state):
self.stdscr.clear()
self.stdscr.addstr(
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.refresh()
class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""
def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
print "%s: %d%% (%d/%d)" % (
table, data["perc"],
data["num_done"], data["total"],
)
def set_state(self, state):
print state + "..."
##############################################
##############################################
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument(
"--sqlite-database", required=True,
help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server"
)
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True,
help="The database config file for the PostgreSQL database"
)
parser.add_argument(
"--curses", action='store_true',
help="display a curses based progress UI"
)
parser.add_argument(
"--batch-size", type=int, default=1000,
help="The number of rows to select from the SQLite table each"
" iteration [default=1000]",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
sqlite_config = {
"name": "sqlite3",
"args": {
"database": args.sqlite_database,
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
},
}
postgres_config = yaml.safe_load(args.postgres_config)
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'")
sys.exit(2)
if postgres_config["name"] != "psycopg2":
sys.stderr.write("Database must use 'psycopg2' connector.")
sys.exit(3)
def start(stdscr=None):
if stdscr:
progress = CursesProgress(stdscr)
else:
progress = TerminalProgress()
porter = Porter(
sqlite_config=sqlite_config,
postgres_config=postgres_config,
progress=progress,
batch_size=args.batch_size,
)
reactor.callWhenRunning(porter.run)
reactor.run()
if args.curses:
curses.wrapper(start)
else:
start()
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)

View file

@ -37,9 +37,13 @@ textarea, input {
margin: auto margin: auto
} }
.g-recaptcha div {
margin: auto;
}
#registrationForm { #registrationForm {
text-align: left; text-align: left;
padding: 1em; padding: 5px;
margin-bottom: 40px; margin-bottom: 40px;
display: inline-block; display: inline-block;

View file

@ -18,7 +18,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo from synapse.types import UserID, ClientInfo
@ -40,6 +40,7 @@ class Auth(object):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -222,6 +223,13 @@ class Auth(object):
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % raise AuthError(403, "%s is already in the room." %
target_user_id) target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were: # Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation # invited: They are accepting the invitation
@ -362,7 +370,10 @@ class Auth(object):
defer.returnValue((user, ClientInfo(device_id, token_id))) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_token(self, token): def get_user_by_token(self, token):
@ -376,10 +387,12 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try:
ret = yield self.store.get_user_by_token(token) ret = yield self.store.get_user_by_token(token)
if not ret: if not ret:
raise StoreError(400, "Unknown token") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
user_info = { user_info = {
"admin": bool(ret.get("admin", False)), "admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"), "device_id": ret.get("device_id"),
@ -388,9 +401,6 @@ class Auth(object):
} }
defer.returnValue(user_info) defer.returnValue(user_info)
except StoreError:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
@ -398,11 +408,16 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
raise AuthError(403, "Unrecognised access token.", raise AuthError(
errcode=Codes.UNKNOWN_TOKEN) self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
def is_server_admin(self, user): def is_server_admin(self, user):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@ -561,6 +576,7 @@ class Auth(object):
("ban", []), ("ban", []),
("redact", []), ("redact", []),
("kick", []), ("kick", []),
("invite", []),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users")

View file

@ -59,6 +59,9 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
# Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service" APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret" SHARED_SECRET = u"org.matrix.login.shared_secret"

View file

@ -31,6 +31,7 @@ class Codes(object):
BAD_PAGINATION = "M_BAD_PAGINATION" BAD_PAGINATION = "M_BAD_PAGINATION"
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
@ -38,6 +39,7 @@ class Codes(object):
MISSING_PARAM = "M_MISSING_PARAM" MISSING_PARAM = "M_MISSING_PARAM"
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View file

@ -17,8 +17,9 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.storage.engines import create_engine
from synapse.storage import ( from synapse.storage import (
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException, are_all_users_on_domain, UpgradeDatabaseException,
) )
from synapse.server import HomeServer from synapse.server import HomeServer
@ -59,9 +60,9 @@ import os
import re import re
import resource import resource
import subprocess import subprocess
import sqlite3
logger = logging.getLogger(__name__)
logger = logging.getLogger("synapse.app.homeserver")
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
@ -108,13 +109,11 @@ class SynapseHomeServer(HomeServer):
return None return None
def build_db_pool(self): def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool( return adbapi.ConnectionPool(
"sqlite3", self.get_db_name(), name,
check_same_thread=False, **self.db_config.get("args", {})
cp_min=1,
cp_max=1,
cp_openfun=prepare_database, # Prepare the database for each conn
# so that :memory: sqlite works
) )
def create_resource_tree(self, redirect_root_to_web_client): def create_resource_tree(self, redirect_root_to_web_client):
@ -247,6 +246,21 @@ class SynapseHomeServer(HomeServer):
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port) logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
db_conn.cursor(), database_engine, self.hostname
)
if not all_users_native:
sys.stderr.write(
"\n"
"******************************************************\n"
"Found users in database not native to %s!\n"
"You cannot changed a synapse server_name after it's been configured\n"
"******************************************************\n"
"\n" % (self.hostname,)
)
sys.exit(1)
def get_version_string(): def get_version_string():
try: try:
@ -358,15 +372,20 @@ def setup(config_options):
tls_context_factory = context_factory.ServerContextFactory(config) tls_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port, domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_name=config.database_path,
db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string, version_string=version_string,
database_engine=database_engine,
) )
hs.create_resource_tree( hs.create_resource_tree(
@ -378,9 +397,17 @@ def setup(config_options):
logger.info("Preparing database: %s...", db_name) logger.info("Preparing database: %s...", db_name)
try: try:
with sqlite3.connect(db_name) as db_conn: db_conn = database_engine.module.connect(
prepare_sqlite3_database(db_conn) **{
prepare_database(db_conn) k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(
"\nFailed to upgrade database.\n" "\nFailed to upgrade database.\n"

View file

@ -158,9 +158,10 @@ class Config(object):
and value is not None): and value is not None):
config[key] = value config[key] = value
with open(config_args.config_path, "w") as config_file: with open(config_args.config_path, "w") as config_file:
# TODO(paul) it would be lovely if we wrote out vim- and emacs- # TODO(mark/paul) We might want to output emacs-style mode
# style mode markers into the file, to hint to people that # markers as well as vim-style mode markers into the file,
# this is a YAML file. # to further hint to people this is a YAML file.
config_file.write("# vim:ft=yaml\n")
yaml.dump(config, config_file, default_flow_style=False) yaml.dump(config, config_file, default_flow_style=False)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %s for server name"

View file

@ -20,6 +20,7 @@ class CaptchaConfig(Config):
def __init__(self, args): def __init__(self, args):
super(CaptchaConfig, self).__init__(args) super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key self.recaptcha_private_key = args.recaptcha_private_key
self.recaptcha_public_key = args.recaptcha_public_key
self.enable_registration_captcha = args.enable_registration_captcha self.enable_registration_captcha = args.enable_registration_captcha
self.captcha_ip_origin_is_x_forwarded = ( self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded args.captcha_ip_origin_is_x_forwarded
@ -30,9 +31,13 @@ class CaptchaConfig(Config):
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser) super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha") group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
help="This Home Server's ReCAPTCHA public key."
)
group.add_argument( group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="The matching private key for the web client's public key." help="This Home Server's ReCAPTCHA private key."
) )
group.add_argument( group.add_argument(
"--enable-registration-captcha", type=bool, default=False, "--enable-registration-captcha", type=bool, default=False,

View file

@ -15,6 +15,7 @@
from ._base import Config from ._base import Config
import os import os
import yaml
class DatabaseConfig(Config): class DatabaseConfig(Config):
@ -26,18 +27,45 @@ class DatabaseConfig(Config):
self.database_path = self.abspath(args.database_path) self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size) self.event_cache_size = self.parse_size(args.event_cache_size)
if args.database_config:
with open(args.database_config) as f:
self.database_config = yaml.safe_load(f)
else:
self.database_config = {
"name": "sqlite3",
"args": {
"database": self.database_path,
},
}
name = self.database_config.get("name", None)
if name == "psycopg2":
pass
elif name == "sqlite3":
self.database_config.setdefault("args", {}).update({
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
})
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(DatabaseConfig, cls).add_arguments(parser) super(DatabaseConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("database") db_group = parser.add_argument_group("database")
db_group.add_argument( db_group.add_argument(
"-d", "--database-path", default="homeserver.db", "-d", "--database-path", default="homeserver.db",
help="The database name." metavar="SQLITE_DATABASE_PATH", help="The database name."
) )
db_group.add_argument( db_group.add_argument(
"--event-cache-size", default="100K", "--event-cache-size", default="100K",
help="Number of events to cache in memory." help="Number of events to cache in memory."
) )
db_group.add_argument(
"--database-config", default=None,
help="Location of the database configuration file."
)
@classmethod @classmethod
def generate_config(cls, args, config_dir_path): def generate_config(cls, args, config_dir_path):

View file

@ -1,42 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class EmailConfig(Config):
def __init__(self, args):
super(EmailConfig, self).__init__(args)
self.email_from_address = args.email_from_address
self.email_smtp_server = args.email_smtp_server
@classmethod
def add_arguments(cls, parser):
super(EmailConfig, cls).add_arguments(parser)
email_group = parser.add_argument_group("email")
email_group.add_argument(
"--email-from-address",
default="FROM@EXAMPLE.COM",
help="The address to send emails from (e.g. for password resets)."
)
email_group.add_argument(
"--email-smtp-server",
default="",
help=(
"The SMTP server to send emails from (e.g. for password"
" resets)."
)
)

View file

@ -20,7 +20,6 @@ from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
@ -30,7 +29,7 @@ from .key import KeyConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig, RegistrationConfig, VoipConfig, RegistrationConfig,
MetricsConfig, AppServiceConfig, KeyConfig,): MetricsConfig, AppServiceConfig, KeyConfig,):
pass pass

View file

@ -78,7 +78,6 @@ class LoggingConfig(Config):
handler.addFilter(LoggingContextFilter(request="")) handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler) logger.addHandler(handler)
logger.info("Test")
else: else:
with open(self.log_config, 'r') as f: with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f)) logging.config.dictConfig(yaml.load(f))

View file

@ -30,6 +30,8 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
class Handlers(object): class Handlers(object):
@ -64,3 +66,5 @@ class Handlers(object):
) )
) )
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)

View file

@ -16,7 +16,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID
@ -58,8 +57,6 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
yield run_on_reactor()
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id, builder.room_id,
) )
@ -101,8 +98,6 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[], def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False): extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth: if not suppress_auth:
@ -143,7 +138,9 @@ class BaseHandler(object):
) )
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
d = self.notifier.on_new_room_event(event, extra_users=extra_users) notify_d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -151,8 +148,10 @@ class BaseHandler(object):
event.event_id, f.value event.event_id, f.value
) )
d.addErrback(log_failure) notify_d.addErrback(log_failure)
yield federation_handler.handle_new_event( fed_d = federation_handler.handle_new_event(
event, destinations=destinations, event, destinations=destinations,
) )
fed_d.addErrback(log_failure)

277
synapse/handlers/auth.py Normal file
View file

@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import simplejson
import synapse.util.stringutils as stringutils
logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
Args:
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
Returns:
A tuple of authed, dict, dict where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
"""
authdict = None
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
sess = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
# sess['clientdict'] = clientdict
# self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict)
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
# check auth type currently being presented
if 'type' in authdict:
if authdict['type'] not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(sess)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if 'session' not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(
authdict['session']
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
result = yield self.checkers[stagetype](authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"]
password = authdict["password"]
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s",
user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = SimpleHttpClient(self.hs)
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
'remoteip': clientip,
}
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
def _auth_dict_for_flows(self, flows, session):
public_flows = []
for f in flows:
public_flows.append(f)
get_params = {
LoginType.RECAPTCHA: self._get_params_recaptcha,
}
params = {}
for f in public_flows:
for stage in f:
if stage in get_params and stage not in params:
params[stage] = get_params[stage]()
return {
"session": session['id'],
"flows": [{"stages": f} for f in public_flows],
"params": params
}
def _get_session_info(self, session_id):
if session_id not in self.sessions:
session_id = None
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]

View file

@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state # it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap) # for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin) retry_timings = yield self.store.get_destination_retry_timings(origin)
if (retry_timings and retry_timings.retry_last_ts): if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0) self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)

View file

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
from twisted.internet import defer
from synapse.api.errors import (
CodeMessageException
)
from ._base import BaseHandler
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
import json
import logging
logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org']
if not creds['id_server'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['id_server'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
"https://%s%s" % (
creds['id_server'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': creds['client_secret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
creds['id_server'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': creds['client_secret'],
'mxid': mxid,
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View file

@ -16,13 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes, CodeMessageException from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +53,7 @@ class LoginHandler(BaseHandler):
logger.warn("Attempted to login as %s but they do not exist", user) logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"] stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash): if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it. # generate an access token and store it.
token = self.reg_handler._generate_token(user) token = self.reg_handler._generate_token(user)
@ -69,48 +65,19 @@ class LoginHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def reset_password(self, user_id, email): def set_password(self, user_id, newpassword, token_id=None):
is_valid = yield self._check_valid_association(user_id, email) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid) yield self.store.user_set_password_hash(user_id, password_hash)
if is_valid: yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
try: yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
# send an email out user_id, token_id
emailutils.send_email(
smtp_server=self.hs.config.email_smtp_server,
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
) )
except EmailException as e: yield self.store.flush_user(user_id)
logger.exception(e)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_valid_association(self, user_id, email): def add_threepid(self, user_id, medium, address, validated_at):
identity = yield self._query_email(email) yield self.store.user_add_threepid(
if identity and "mxid" in identity: user_id, medium, address, validated_at,
if identity["mxid"] == user_id: self.hs.get_clock().time_msec()
defer.returnValue(True)
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
http_client = SimpleHttpClient(self.hs)
try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
) )
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View file

@ -274,7 +274,8 @@ class MessageHandler(BaseHandler):
if limit is None: if limit is None:
limit = 10 limit = 10
for event in room_list: @defer.inlineCallbacks
def handle_room(event):
d = { d = {
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
@ -290,12 +291,19 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d) rooms_ret.append(d)
if event.membership != Membership.JOIN: if event.membership != Membership.JOIN:
continue return
try: try:
messages, token = yield self.store.get_recent_events_for_room( (messages, token), current_state = yield defer.gatherResults(
[
self.store.get_recent_events_for_room(
event.room_id, event.room_id,
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=now_token.room_key,
),
self.state_handler.get_current_state(
event.room_id
),
]
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -311,9 +319,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
} }
current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [ d["state"] = [
serialize_event(c, time_now, as_client_event) serialize_event(c, time_now, as_client_event)
for c in current_state.values() for c in current_state.values()
@ -321,6 +326,11 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
yield defer.gatherResults(
[handle_room(e) for e in room_list],
consumeErrors=True
)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,

View file

@ -858,22 +858,24 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
max_serial = presence._user_cachemap_latest_serial
clock = self.clock clock = self.clock
latest_serial = None latest_serial = 0
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. # TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys(): for observed_user in cachemap.keys():
cached = cachemap[observed_user] cached = cachemap[observed_user]
if cached.serial <= from_key: if cached.serial <= from_key or cached.serial > max_serial:
continue continue
if not (yield self.is_visible(observer_user, observed_user)): if not (yield self.is_visible(observer_user, observed_user)):
continue continue
if latest_serial is None or cached.serial > latest_serial: latest_serial = max(cached.serial, latest_serial)
latest_serial = cached.serial
updates.append(cached.make_event(user=observed_user, clock=clock)) updates.append(cached.make_event(user=observed_user, clock=clock))
# TODO(paul): limit # TODO(paul): limit
@ -882,6 +884,10 @@ class PresenceEventSource(object):
if serial < from_key: if serial < from_key:
break break
if serial > max_serial:
continue
latest_serial = max(latest_serial, serial)
for u in user_ids: for u in user_ids:
updates.append({ updates.append({
"type": "m.presence", "type": "m.presence",

View file

@ -18,18 +18,15 @@ from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
import bcrypt import bcrypt
import json
import logging import logging
import urllib import urllib
@ -44,6 +41,30 @@ class RegistrationHandler(BaseHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
@defer.inlineCallbacks
def check_username(self, localpart):
yield run_on_reactor()
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id)
if u:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None):
"""Registers a new client on the server. """Registers a new client on the server.
@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart: if localpart:
if localpart and urllib.quote(localpart) != localpart: yield self.check_username(localpart)
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):
"""Checks a recaptcha is correct.""" """
Checks a recaptcha is correct.
Used only by c/s api v1
"""
captcha_response = yield self._validate_captcha( captcha_response = yield self._validate_captcha(
ip, ip,
@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def register_email(self, threepidCreds): def register_email(self, threepidCreds):
"""Registers emails with an identity server.""" """
Registers emails with an identity server.
Used only by c/s api v1
"""
for c in threepidCreds: for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s", logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer']) c['sid'], c['idServer'])
try: try:
threepid = yield self._threepid_from_creds(c) identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
except: except:
logger.exception("Couldn't validate 3pid") logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid")
@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds): def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.""" """Links emails with a user ID and informs an identity server.
Used only by c/s api v1
"""
# Now we have a matrix ID, bind it to the threepids we were given # Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds: for c in threepidCreds:
identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_id_is_valid(self, user_id): def check_user_id_is_valid(self, user_id):
@ -226,62 +253,12 @@ class RegistrationHandler(BaseHandler):
def _generate_user_id(self): def _generate_user_id(self):
return "-" + stringutils.random_string(18) return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
# XXX: This should be HTTPS
"http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def _bind_threepid(self, creds, mxid):
yield
logger.debug("binding threepid")
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
# XXX: Change when ID servers are all HTTPS
"http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'clientSecret': creds['clientSecret'],
'mxid': mxid,
}
)
logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided. """Validates the captcha provided.
Used only by c/s api v1
Returns: Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response): def _submit_captcha(self, ip_addr, private_key, challenge, response):
"""
Used only by c/s api v1
"""
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
client = CaptchaServerHttpClient(self.hs) client = CaptchaServerHttpClient(self.hs)

View file

@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler):
"state_default": 50, "state_default": 50,
"ban": 50, "ban": 50,
"kick": 50, "kick": 50,
"redact": 50 "redact": 50,
"invite": 0,
}, },
) )
@ -310,25 +311,6 @@ class RoomMemberHandler(BaseHandler):
# paginating # paginating
defer.returnValue(chunk_data) defer.returnValue(chunk_data)
@defer.inlineCallbacks
def get_room_member(self, room_id, member_user_id, auth_user_id):
"""Retrieve a room member from a room.
Args:
room_id : The room the member is in.
member_user_id : The member's user ID
auth_user_id : The user ID of the user making this request.
Returns:
The room member, or None if this member does not exist.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, auth_user_id)
member = yield self.store.get_room_member(user_id=member_user_id,
room_id=room_id)
defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.

View file

@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient):
""" """
Separate HTTP client for talking to google's captcha servers Separate HTTP client for talking to google's captcha servers
Only slightly special because accepts partial download responses Only slightly special because accepts partial download responses
used only by c/s api v1
""" """
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -131,10 +131,10 @@ class HttpServer(object):
""" """
def register_path(self, method, path_pattern, callback): def register_path(self, method, path_pattern, callback):
""" Register a callback that get's fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
If the regex contains groups these get's passed to the calback via If the regex contains groups these gets passed to the calback via
an unpacked tuple. an unpacked tuple.
Args: Args:
@ -153,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource):
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_path()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
The JsonResource is primarily intended for returning JSON, but callbacks
may send something other than JSON, they may do so by using the methods
on the request object and instead returning None.
""" """
isLeaf = True isLeaf = True
@ -185,9 +192,8 @@ class JsonResource(HttpServer, resource.Resource):
interface=self.hs.config.bind_host interface=self.hs.config.bind_host
) )
# Gets called by twisted
def render(self, request): def render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render(request) self._async_render(request)
return server.NOT_DONE_YET return server.NOT_DONE_YET
@ -195,7 +201,7 @@ class JsonResource(HttpServer, resource.Resource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
@ -227,9 +233,11 @@ class JsonResource(HttpServer, resource.Resource):
urllib.unquote(u).decode("UTF-8") for u in m.groups() urllib.unquote(u).decode("UTF-8") for u in m.groups()
] ]
code, response = yield callback(request, *args) callback_return = yield callback(request, *args)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
response_timer.inc_by( response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname self.clock.time_msec() - start, request.method, servlet_classname
) )

View file

@ -98,7 +98,7 @@ class _NotificationListener(object):
try: try:
notifier.clock.cancel_call_later(self.timer) notifier.clock.cancel_call_later(self.timer)
except: except:
logger.exception("Failed to cancel notifier timer") logger.warn("Failed to cancel notifier timer")
class Notifier(object): class Notifier(object):

View file

@ -253,7 +253,8 @@ class Pusher(object):
self.user_name, config, timeout=0) self.user_name, config, timeout=0)
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.last_token) self.app_id, self.pushkey, self.user_name, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_name, self.last_token)
@ -314,7 +315,7 @@ class Pusher(object):
pk pk
) )
yield self.hs.get_pusherpool().remove_pusher( yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk self.app_id, pk, self.user_name
) )
if not self.alive: if not self.alive:
@ -326,6 +327,7 @@ class Pusher(object):
self.store.update_pusher_last_token_and_success( self.store.update_pusher_last_token_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token, self.last_token,
self.clock.time_msec() self.clock.time_msec()
) )
@ -334,6 +336,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since) self.failing_since)
else: else:
if not self.failing_since: if not self.failing_since:
@ -341,6 +344,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
@ -358,6 +362,7 @@ class Pusher(object):
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token self.last_token
) )
@ -365,6 +370,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
else: else:

View file

@ -19,10 +19,7 @@ from twisted.internet import defer
from httppusher import HttpPusher from httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from syutil.jsonutil import encode_canonical_json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,12 +49,10 @@ class PusherPool:
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
pushers = yield self.store.get_all_pushers() pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id, def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
@ -71,7 +66,7 @@ class PusherPool:
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(), "ts": self.hs.get_clock().time_msec(),
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
@ -79,17 +74,50 @@ class PusherPool:
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self._add_pusher_to_store(
user_name, profile_tag, kind, app_id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data pushkey, lang, data
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
app_display_name, device_display_name, not_user_id):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
for p in to_remove:
if p['user_name'] != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access token %s",
user_id, not_access_token_id
)
for p in all:
if (
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data): pushkey, lang, data):
yield self.store.add_pusher( yield self.store.add_pusher(
user_name=user_name, user_name=user_name,
access_token=access_token,
profile_tag=profile_tag, profile_tag=profile_tag,
kind=kind, kind=kind,
app_id=app_id, app_id=app_id,
@ -98,9 +126,9 @@ class PusherPool:
pushkey=pushkey, pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(), pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang, lang=lang,
data=encode_canonical_json(data).decode("UTF-8"), data=data,
) )
self._refresh_pusher((app_id, pushkey)) self._refresh_pusher(app_id, pushkey, user_name)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
@ -112,7 +140,7 @@ class PusherPool:
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'], device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'], pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'], pushkey_ts=pusherdict['ts'],
data=pusherdict['data'], data=pusherdict['data'],
last_token=pusherdict['last_token'], last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'], last_success=pusherdict['last_success'],
@ -125,11 +153,17 @@ class PusherPool:
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey): def _refresh_pusher(self, app_id, pushkey, user_name):
p = yield self.store.get_pushers_by_app_id_and_pushkey( resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey app_id, pushkey
) )
p['data'] = json.loads(p['data'])
p = None
for r in resultlist:
if r['user_name'] == user_name:
p = r
if p:
self._start_pushers([p]) self._start_pushers([p])
@ -138,17 +172,23 @@ class PusherPool:
for pusherdict in pushers: for pusherdict in pushers:
p = self._create_pusher(pusherdict) p = self._create_pusher(pusherdict)
if p: if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) fullid = "%s:%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
pusherdict['user_name']
)
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() p.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey): def remove_pusher(self, app_id, pushkey, user_name):
fullid = "%s:%s" % (app_id, pushkey) fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
if fullid in self.pushers: if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid) logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop() self.pushers[fullid].stop()
del self.pushers[fullid] del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) yield self.store.delete_pusher_by_app_id_pushkey_user_name(
app_id, pushkey, user_name
)

View file

@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View file

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and and 'kind' in content and
content['kind'] is None): content['kind'] is None):
yield pusher_pool.remove_pusher( yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey'] content['app_id'], content['pushkey'], user_name=user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
append = False
if 'append' in content:
append = content['append']
if not append:
yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_name=user.to_string(),
access_token=client.token_id,
profile_tag=content['profile_tag'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],

View file

@ -15,7 +15,10 @@
from . import ( from . import (
sync, sync,
filter filter,
account,
register,
auth
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)

View file

@ -17,9 +17,11 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,3 +38,23 @@ def client_v2_pattern(path_regex):
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View file

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet
from synapse.util.async import run_on_reactor
from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body)
if not authed:
defer.returnValue((401, result))
user_id = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, client = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
# if using email, we must know about the email they're authing with!
threepid_user = yield self.hs.get_datastore().get_user_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.login_handler.set_password(
user_id, new_password, None
)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
class ThreepidRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/3pid")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string()
)
defer.returnValue((200, {'threepids': threepids}))
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
if 'threePidCreds' not in body:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
auth_user, client = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
if not threepid:
raise SynapseError(
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
)
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind' in body and body['bind']:
logger.debug(
"Binding emails %s to %s",
threepid, auth_user.to_string()
)
yield self.identity_handler.bind_threepid(
threePidCreds, auth_user.to_string()
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View file

@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = """
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script src="https://www.google.com/recaptcha/api.js"
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
function captchaDone() {
$('#registrationForm').submit();
}
</script>
</head>
<body>
<form id="registrationForm" method="post" action="%(myurl)s">
<div>
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.
</p>
<p>
Please verify that you're not a robot.
</p>
<input type="hidden" name="session" value="%(session)s" />
<div class="g-recaptcha"
data-sitekey="%(sitekey)s"
data-callback="captchaDone">
</div>
<noscript>
<input type="submit" value="All Done" />
</noscript>
</div>
</div>
</form>
</body>
</html>
"""
SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone != undefined) {
window.onAuthDone();
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""
class AuthRestServlet(RestServlet):
"""
Handles Client / Server API authentication in any situations where it
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks
def on_GET(self, request, stagetype):
yield
if stagetype == LoginType.RECAPTCHA:
if ('session' not in request.args or
len(request.args['session']) == 0):
raise SynapseError(400, "No session supplied")
session = request.args["session"][0]
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
@defer.inlineCallbacks
def on_POST(self, request, stagetype):
yield
if stagetype == "m.login.recaptcha":
if ('g-recaptcha-response' not in request.args or
len(request.args['g-recaptcha-response'])) == 0:
raise SynapseError(400, "No captcha response supplied")
if ('session' not in request.args or
len(request.args['session'])) == 0:
raise SynapseError(400, "No session supplied")
session = request.args['session'][0]
authdict = {
'response': request.args['g-recaptcha-response'][0],
'session': session,
}
success = yield self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA,
authdict,
self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
AuthRestServlet(hs).register(http_server)

View file

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty
import logging
import hmac
from hashlib import sha1
from synapse.util.async import run_on_reactor
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_request_allow_empty(request)
if 'password' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
if 'username' in body:
desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False
is_application_server = False
service = None
if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request)
if self.hs.config.enable_registration_captcha:
flows = [
[LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
]
else:
flows = [
[LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY]
]
if service:
is_application_server = True
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password']
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.login_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
def _check_shared_secret_auth(self, username, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = username.encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(mac)
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
if compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError(
403, "HMAC incorrect",
)
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View file

@ -65,6 +65,7 @@ class BaseHomeServer(object):
'replication_layer', 'replication_layer',
'datastore', 'datastore',
'handlers', 'handlers',
'v1auth',
'auth', 'auth',
'rest_servlet_factory', 'rest_servlet_factory',
'state_handler', 'state_handler',
@ -182,6 +183,15 @@ class HomeServer(BaseHomeServer):
def build_auth(self): def build_auth(self):
return Auth(self) return Auth(self)
def build_v1auth(self):
orf = Auth(self)
# Matrix spec makes no reference to what HTTP status code is returned,
# but the V1 API uses 403 where it means 401, and the webclient
# relies on this behaviour, so V1 gets its own copy of the auth
# with backwards compat behaviour.
orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
return orf
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)

View file

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 15 SCHEMA_VERSION = 16
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -104,14 +104,16 @@ class DataStore(RoomMemberStore, RoomStore,
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(*key + (now,))
yield self._simple_insert( yield self._simple_upsert(
"user_ips", "user_ips",
{ keyvalues={
"user": user.to_string(), "user_id": user.to_string(),
"access_token": access_token, "access_token": access_token,
"device_id": device_id,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
},
values={
"device_id": device_id,
"last_seen": now, "last_seen": now,
}, },
desc="insert_client_ip", desc="insert_client_ip",
@ -120,7 +122,7 @@ class DataStore(RoomMemberStore, RoomStore,
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
return self._simple_select_list( return self._simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user": user.to_string()}, keyvalues={"user_id": user.to_string()},
retcols=[ retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen" "device_id", "access_token", "ip", "user_agent", "last_seen"
], ],
@ -148,21 +150,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass pass
def prepare_database(db_conn): def prepare_database(db_conn, database_engine):
"""Prepares a database for usage. Will either create all necessary tables """Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version. or upgrade from an older schema version.
""" """
try: try:
cur = db_conn.cursor() cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur) version_info = _get_or_create_schema_state(cur, database_engine)
if version_info: if version_info:
user_version, delta_files, upgraded = version_info user_version, delta_files, upgraded = version_info
_upgrade_existing_database(cur, user_version, delta_files, upgraded) _upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine
)
else: else:
_setup_new_database(cur) _setup_new_database(cur, database_engine)
cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close() cur.close()
db_conn.commit() db_conn.commit()
@ -171,7 +175,7 @@ def prepare_database(db_conn):
raise raise
def _setup_new_database(cur): def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then """Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas. applying any necessary deltas.
@ -225,31 +229,30 @@ def _setup_new_database(cur):
directory_entries = os.listdir(sql_dir) directory_entries = os.listdir(sql_dir)
sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"): for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename) sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc) logger.debug("Applying schema %s", sql_loc)
sql_script += read_schema(sql_loc) executescript(cur, sql_loc)
sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
cur.executescript(sql_script)
cur.execute( cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" database_engine.convert_param_style(
" VALUES (?,?)", "INSERT INTO schema_version (version, upgraded)"
(max_current_ver, False) " VALUES (?,?)"
),
(max_current_ver, False,)
) )
_upgrade_existing_database( _upgrade_existing_database(
cur, cur,
current_version=max_current_ver, current_version=max_current_ver,
applied_delta_files=[], applied_delta_files=[],
upgraded=False upgraded=False,
database_engine=database_engine,
) )
def _upgrade_existing_database(cur, current_version, applied_delta_files, def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded): upgraded, database_engine):
"""Upgrades an existing database. """Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -305,6 +308,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if not upgraded: if not upgraded:
start_ver += 1 start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1): for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v) logger.debug("Upgrading schema to v%d", v)
@ -321,6 +326,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
directory_entries.sort() directory_entries.sort()
for file_name in directory_entries: for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name) relative_path = os.path.join(str(v), file_name)
logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files: if relative_path in applied_delta_files:
continue continue
@ -342,9 +348,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module.run_upgrade(cur) module.run_upgrade(cur)
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)
cur.executescript(delta_schema) executescript(cur, absolute_path)
else: else:
# Not a valid delta file. # Not a valid delta file.
logger.warn( logger.warn(
@ -356,24 +361,82 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done. # Mark as done.
cur.execute( cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)" "INSERT INTO applied_schema_deltas (version, file)"
" VALUES (?,?)", " VALUES (?,?)",
),
(v, relative_path) (v, relative_path)
) )
cur.execute( cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" database_engine.convert_param_style(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)", " VALUES (?,?)",
),
(v, True) (v, True)
) )
def _get_or_create_schema_state(txn): def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
for line in f:
line = line.strip()
if in_comment:
# Check if this line contains an end to the comment
comments = line.split("*/", 1)
if len(comments) == 1:
continue
line = comments[1]
in_comment = False
# Remove inline block comments
line = re.sub(r"/\*.*\*/", " ", line)
# Does this line start a comment?
comments = line.split("/*", 1)
if len(comments) > 1:
line = comments[0]
in_comment = True
# Deal with line comments
line = line.split("--", 1)[0]
line = line.split("//", 1)[0]
# Find *all* semicolons. We need to treat first and last entry
# specially.
statements = line.split(";")
# We must prepend statement_buffer to the first statement
first_statement = "%s %s" % (
statement_buffer.strip(),
statements[0].strip()
)
statements[0] = first_statement
# Every entry, except the last, is a full statement
for statement in statements[:-1]:
yield statement.strip()
# The last entry did *not* end in a semicolon, so we store it for the
# next semicolon we find
statement_buffer = statements[-1].strip()
def executescript(txn, schema_path):
with open(schema_path, 'r') as f:
for statement in get_statements(f):
txn.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables.
schema_path = os.path.join( schema_path = os.path.join(
dir_path, "schema", "schema_version.sql", dir_path, "schema", "schema_version.sql",
) )
create_schema = read_schema(schema_path) executescript(txn, schema_path)
txn.executescript(create_schema)
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone() row = txn.fetchone()
@ -382,10 +445,13 @@ def _get_or_create_schema_state(txn):
if current_version: if current_version:
txn.execute( txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?", database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,) (current_version,)
) )
return current_version, txn.fetchall(), upgraded applied_deltas = [d for d, in txn.fetchall()]
return current_version, applied_deltas, upgraded
return None return None
@ -417,7 +483,19 @@ def prepare_sqlite3_database(db_conn):
if row and row[0]: if row and row[0]:
db_conn.execute( db_conn.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" "REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)", " VALUES (?,?)",
(row[0], False) (row[0], False)
) )
def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
)
pat = "%:" + domain
txn.execute(sql, (pat,))
num_not_matching = txn.fetchall()[0][0]
if num_not_matching == 0:
return True
return False

View file

@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
import synapse.metrics import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
@ -145,11 +147,12 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = ["txn", "name"] __slots__ = ["txn", "name", "database_engine"]
def __init__(self, txn, name): def __init__(self, txn, name, database_engine):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -161,25 +164,31 @@ class LoggingTransaction(object):
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try: sql = self.database_engine.convert_param_style(sql)
if args and args[0]: if args and args[0]:
values = args[0] args = list(args)
args[0] = [
self.database_engine.encode_parameter(a) for a in args[0]
]
try:
sql_logger.debug( sql_logger.debug(
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)), "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
self.name, self.name,
*values *args[0]
) )
except: except:
# Don't let logging failures stop SQL from working # Don't let logging failures stop SQL from working
pass pass
start = time.time() * 1000 start = time.time() * 1000
try: try:
return self.txn.execute( return self.txn.execute(
sql, *args, **kwargs sql, *args, **kwargs
) )
except: except Exception as e:
logger.exception("[SQL FAIL] {%s}", self.name) logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise raise
finally: finally:
msecs = (time.time() * 1000) - start msecs = (time.time() * 1000) - start
@ -245,6 +254,14 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator()
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()
@ -281,8 +298,11 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
def inner_func(txn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn):
conn.reconnect()
current_context.copy_to(context) current_context.copy_to(context)
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -296,9 +316,48 @@ class SQLBaseStore(object):
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name) transaction_logger.debug("[TXN START] {%s}", name)
try: try:
return func(LoggingTransaction(txn, name), *args, **kwargs) i = 0
except: N = 5
logger.exception("[TXN FAIL] {%s}", name) while True:
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, name, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise raise
finally: finally:
end = time.time() * 1000 end = time.time() * 1000
@ -311,7 +370,7 @@ class SQLBaseStore(object):
sql_txn_timer.inc_by(duration, desc) sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext(): with PreserveLoggingContext():
result = yield self._db_pool.runInteraction( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
defer.returnValue(result) defer.returnValue(result)
@ -342,11 +401,11 @@ class SQLBaseStore(object):
The result of decoder(results) The result of decoder(results)
""" """
def interaction(txn): def interaction(txn):
cursor = txn.execute(query, args) txn.execute(query, args)
if decoder: if decoder:
return decoder(cursor) return decoder(txn)
else: else:
return cursor.fetchall() return txn.fetchall()
return self.runInteraction(desc, interaction) return self.runInteraction(desc, interaction)
@ -356,27 +415,29 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False, or_ignore=False, @defer.inlineCallbacks
def _simple_insert(self, table, values, or_ignore=False,
desc="_simple_insert"): desc="_simple_insert"):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
table : string giving the table name table : string giving the table name
values : dict of new column names and values for them values : dict of new column names and values for them
or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self.runInteraction( try:
yield self.runInteraction(
desc, desc,
self._simple_insert_txn, table, values, or_replace=or_replace, self._simple_insert_txn, table, values,
or_ignore=or_ignore,
) )
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
if not or_ignore:
raise
@log_function @log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False, def _simple_insert_txn(self, txn, table, values):
or_ignore=False): sql = "INSERT INTO %s (%s) VALUES(%s)" % (
sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else
"INSERT OR IGNORE" if or_ignore else "INSERT"),
table, table,
", ".join(k for k in values), ", ".join(k for k in values),
", ".join("?" for k in values) ", ".join("?" for k in values)
@ -388,22 +449,26 @@ class SQLBaseStore(object):
) )
txn.execute(sql, values.values()) txn.execute(sql, values.values())
return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"): def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert"):
""" """
Args: Args:
table (str): The table to upsert into table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values values (dict): The nonunique columns and their new values
insertion_values (dict): key/values to use when inserting
Returns: A deferred Returns: A deferred
""" """
return self.runInteraction( return self.runInteraction(
desc, desc,
self._simple_upsert_txn, table, keyvalues, values self._simple_upsert_txn, table, keyvalues, values, insertion_values,
) )
def _simple_upsert_txn(self, txn, table, keyvalues, values): def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}):
# We need to lock the table :(
self.database_engine.lock_table(txn, table)
# Try to update # Try to update
sql = "UPDATE %s SET %s WHERE %s" % ( sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
@ -422,6 +487,7 @@ class SQLBaseStore(object):
allvalues = {} allvalues = {}
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(values) allvalues.update(values)
allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table, table,
@ -489,8 +555,7 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = ( sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s " "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
"ORDER BY rowid asc"
) % { ) % {
"retcol": retcol, "retcol": retcol,
"table": table, "table": table,
@ -548,14 +613,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
else: else:
sql = "SELECT %s FROM %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s" % (
", ".join(retcols), ", ".join(retcols),
table table
) )
@ -607,10 +672,10 @@ class SQLBaseStore(object):
def _simple_select_one_txn(self, txn, table, keyvalues, retcols, def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
allow_none=False): allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
txn.execute(select_sql, keyvalues.values()) txn.execute(select_sql, keyvalues.values())
@ -648,6 +713,11 @@ class SQLBaseStore(object):
updatevalues=updatevalues, updatevalues=updatevalues,
) )
# if txn.rowcount == 0:
# raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret return ret
return self.runInteraction(desc, func) return self.runInteraction(desc, func)
@ -860,6 +930,12 @@ class SQLBaseStore(object):
result = txn.fetchone() result = txn.fetchone()
return result[0] if result else None return result[0] if result else None
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying
@ -883,7 +959,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s" _select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s" _select_clause = "SELECT %s FROM %s"
_insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod @classmethod
def select_statement(cls, where_clause=None): def select_statement(cls, where_clause=None):

View file

@ -366,11 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
new_txn_id = max(highest_txn_id, last_txn_id) + 1 new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table # Insert new txn into txn table
event_ids = [e.event_id for e in events] event_ids = json.dumps([e.event_id for e in events])
txn.execute( txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)", "VALUES(?,?,?)",
(service.id, new_txn_id, json.dumps(event_ids)) (service.id, new_txn_id, event_ids)
) )
return AppServiceTransaction( return AppServiceTransaction(
service=service, id=new_txn_id, events=events service=service, id=new_txn_id, events=events

View file

@ -21,8 +21,6 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
import sqlite3
RoomAliasMapping = namedtuple( RoomAliasMapping = namedtuple(
"RoomAliasMapping", "RoomAliasMapping",
@ -91,7 +89,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
except sqlite3.IntegrityError: except self.database_engine.module.IntegrityError:
raise SynapseError( raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string() 409, "Room alias %s already exists" % room_alias.to_string()
) )
@ -120,12 +118,12 @@ class DirectoryStore(SQLBaseStore):
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):
cursor = txn.execute( txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?", "SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),) (room_alias.to_string(),)
) )
res = cursor.fetchone() res = txn.fetchone()
if res: if res:
room_id = res[0] room_id = res[0]
else: else:

View file

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
import importlib
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
"psycopg2": PostgresEngine,
}
def create_engine(name):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
module = importlib.import_module(name)
return engine_class(module)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)
)

View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import prepare_database
class PostgresEngine(object):
def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
def convert_param_style(self, sql):
return sql.replace("?", "%s")
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
def prepare_database(self, db_conn):
prepare_database(db_conn, self)
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"]
return False
def is_connection_closed(self, conn):
return bool(conn)
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))

View file

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object):
def __init__(self, database_module):
self.module = database_module
def convert_param_style(self, sql):
return sql
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn):
self.prepare_database(db_conn)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self)
def is_deadlock(self, error):
return False
def is_connection_closed(self, conn):
return False
def lock_table(self, txn, table):
return

View file

@ -153,7 +153,7 @@ class EventFederationStore(SQLBaseStore):
results = self._get_prev_events_and_state( results = self._get_prev_events_and_state(
txn, txn,
event_id, event_id,
is_state=1, is_state=True,
) )
return [(e_id, h, ) for e_id, h, _ in results] return [(e_id, h, ) for e_id, h, _ in results]
@ -164,7 +164,7 @@ class EventFederationStore(SQLBaseStore):
} }
if is_state is not None: if is_state is not None:
keyvalues["is_state"] = is_state keyvalues["is_state"] = bool(is_state)
res = self._simple_select_list_txn( res = self._simple_select_list_txn(
txn, txn,
@ -242,7 +242,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"min_depth": depth, "min_depth": depth,
}, },
or_replace=True,
) )
def _handle_prev_events(self, txn, outlier, event_id, prev_events, def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@ -260,9 +259,8 @@ class EventFederationStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": room_id,
"is_state": 0, "is_state": False,
}, },
or_ignore=True,
) )
# Update the extremities table if this is not an outlier. # Update the extremities table if this is not an outlier.
@ -281,19 +279,19 @@ class EventFederationStore(SQLBaseStore):
# We only insert as a forward extremity the new event if there are # We only insert as a forward extremity the new event if there are
# no other events that reference it as a prev event # no other events that reference it as a prev event
query = ( query = (
"INSERT OR IGNORE INTO %(table)s (event_id, room_id) " "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
"SELECT ?, ? WHERE NOT EXISTS (" )
"SELECT 1 FROM %(event_edges)s WHERE "
"prev_event_id = ? "
")"
) % {
"table": "event_forward_extremities",
"event_edges": "event_edges",
}
logger.debug("query: %s", query) txn.execute(query, (event_id,))
txn.execute(query, (event_id, room_id, event_id)) if not txn.fetchone():
query = (
"INSERT INTO event_forward_extremities"
" (event_id, room_id)"
" VALUES (?, ?)"
)
txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get # Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway. # deleted in a second if they're incorrect anyway.
@ -306,7 +304,6 @@ class EventFederationStore(SQLBaseStore):
"event_id": e_id, "event_id": e_id,
"room_id": room_id, "room_id": room_id,
}, },
or_ignore=True,
) )
# Also delete from the backwards extremities table all ones that # Also delete from the backwards extremities table all ones that
@ -400,7 +397,7 @@ class EventFederationStore(SQLBaseStore):
query = ( query = (
"SELECT prev_event_id FROM event_edges " "SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? AND is_state = 0 " "WHERE room_id = ? AND event_id = ? AND is_state = ? "
"LIMIT ?" "LIMIT ?"
) )
@ -409,7 +406,7 @@ class EventFederationStore(SQLBaseStore):
for event_id in front: for event_id in front:
txn.execute( txn.execute(
query, query,
(room_id, event_id, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for e_id, in txn.fetchall(): for e_id, in txn.fetchall():

View file

@ -52,7 +52,6 @@ class EventsStore(SQLBaseStore):
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
self.get_room_events_max_id.invalidate()
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
@ -96,12 +95,22 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id # Remove the any existing cache entries for the event_id
self._invalidate_get_event_cache(event.event_id) self._invalidate_get_event_cache(event.event_id)
if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
return self._persist_event_txn(
txn, event, context, backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
txn.execute( self._simple_delete_txn(
"DELETE FROM current_state_events WHERE room_id = ?", txn,
(event.room_id,) table="current_state_events",
keyvalues={"room_id": event.room_id},
) )
for s in current_state: for s in current_state:
@ -114,7 +123,6 @@ class EventsStore(SQLBaseStore):
"type": s.type, "type": s.type,
"state_key": s.state_key, "state_key": s.state_key,
}, },
or_replace=True,
) )
if event.is_state() and is_new_state: if event.is_state() and is_new_state:
@ -128,7 +136,6 @@ class EventsStore(SQLBaseStore):
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
}, },
or_replace=True,
) )
for prev_state_id, _ in event.prev_state: for prev_state_id, _ in event.prev_state:
@ -151,14 +158,6 @@ class EventsStore(SQLBaseStore):
event.depth event.depth
) )
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn( have_persisted = self._simple_select_one_onecol_txn(
txn, txn,
table="event_json", table="event_json",
@ -169,7 +168,7 @@ class EventsStore(SQLBaseStore):
metadata_json = encode_canonical_json( metadata_json = encode_canonical_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
) ).decode("UTF-8")
# If we have already persisted this event, we don't need to do any # If we have already persisted this event, we don't need to do any
# more processing. # more processing.
@ -185,23 +184,29 @@ class EventsStore(SQLBaseStore):
) )
txn.execute( txn.execute(
sql, sql,
(metadata_json.decode("UTF-8"), event.event_id,) (metadata_json, event.event_id,)
) )
sql = ( sql = (
"UPDATE events SET outlier = 0" "UPDATE events SET outlier = ?"
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute( txn.execute(
sql, sql,
(event.event_id,) (False, event.event_id,)
) )
return return
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback:
self._store_feedback_txn(txn, event)
elif event.type == EventTypes.Name: elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
@ -224,10 +229,9 @@ class EventsStore(SQLBaseStore):
values={ values={
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": metadata_json.decode("UTF-8"), "internal_metadata": metadata_json,
"json": encode_canonical_json(event_dict).decode("UTF-8"), "json": encode_canonical_json(event_dict).decode("UTF-8"),
}, },
or_replace=True,
) )
content = encode_canonical_json( content = encode_canonical_json(
@ -245,9 +249,6 @@ class EventsStore(SQLBaseStore):
"depth": event.depth, "depth": event.depth,
} }
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
unrec = { unrec = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event.get_dict().items()
@ -264,67 +265,24 @@ class EventsStore(SQLBaseStore):
unrec unrec
).decode("UTF-8") ).decode("UTF-8")
try: sql = (
self._simple_insert_txn( "INSERT INTO events"
txn, " (stream_ordering, topological_ordering, event_id, type,"
"events", " room_id, content, processed, outlier, depth)"
vals, " VALUES (?,?,?,?,?,?,?,?,?)"
or_replace=(not outlier), )
or_ignore=bool(outlier),
txn.execute(
sql,
(
stream_ordering, event.depth, event.event_id, event.type,
event.room_id, content, True, outlier, event.depth
) )
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
event.event_id,
exc_info=True,
) )
raise _RollbackButIsFineException("_persist_event")
if context.rejected: if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected) self._store_rejections_txn(txn, event.event_id, context.rejected)
if event.is_state():
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(
txn,
"state_events",
vals,
)
if is_new_state and not context.rejected:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
)
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( self._store_event_content_hash_txn(
@ -354,6 +312,50 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, ref_alg, ref_hash_bytes txn, event.event_id, ref_alg, ref_hash_bytes
) )
if event.is_state():
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(
txn,
"state_events",
vals,
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": True,
},
)
if is_new_state and not context.rejected:
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
self._invalidate_get_event_cache(event.redacts) self._invalidate_get_event_cache(event.redacts)

View file

@ -57,16 +57,18 @@ class KeyStore(SQLBaseStore):
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
) )
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest() fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
return self._simple_insert( return self._simple_upsert(
table="server_tls_certificates", table="server_tls_certificates",
values={ keyvalues={
"server_name": server_name, "server_name": server_name,
"fingerprint": fingerprint, "fingerprint": fingerprint,
},
values={
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes), "tls_certificate": buffer(tls_certificate_bytes),
}, },
or_ignore=True, desc="store_server_certificate",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -107,16 +109,18 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
return self._simple_insert( return self._simple_upsert(
table="server_signature_keys", table="server_signature_keys",
values={ keyvalues={
"server_name": server_name, "server_name": server_name,
"key_id": "%s:%s" % (verify_key.alg, verify_key.version), "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
},
values={
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()), "verify_key": buffer(verify_key.encode()),
}, },
or_ignore=True, desc="store_server_verify_key",
) )
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,

View file

@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
values={"observed_user_id": observed_localpart, values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid}, "observer_user_id": observer_userid},
desc="allow_presence_visible", desc="allow_presence_visible",
or_ignore=True,
) )
def disallow_presence_visible(self, observed_localpart, observer_userid): def disallow_presence_visible(self, observed_localpart, observer_userid):

View file

@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule # now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES (" sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")" sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@ -183,7 +183,7 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES (" sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")" sql += ", ".join(["?" for _ in new_rule.keys()])+")"

View file

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore, Table
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from syutil.jsonutil import encode_canonical_json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,93 +27,55 @@ logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore): class PusherStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
sql = ( sql = (
"SELECT id, user_name, kind, profile_tag, app_id," "SELECT * FROM pushers "
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers "
"WHERE app_id = ? AND pushkey = ?" "WHERE app_id = ? AND pushkey = ?"
) )
rows = yield self._execute( rows = yield self._execute_and_decode(
"get_pushers_by_app_id_and_pushkey", None, sql, "get_pushers_by_app_id_and_pushkey",
app_id_and_pushkey[0], app_id_and_pushkey[1] sql,
app_id, pushkey
) )
ret = [ defer.returnValue(rows)
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"profile_tag": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret[0])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_all_pushers(self): def get_all_pushers(self):
sql = ( sql = (
"SELECT id, user_name, kind, profile_tag, app_id," "SELECT * FROM pushers"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers"
) )
rows = yield self._execute("get_all_pushers", None, sql) rows = yield self._execute_and_decode("get_all_pushers", sql)
ret = [ defer.returnValue(rows)
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"profile_tag": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id, def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data): pushkey, pushkey_ts, lang, data):
try: try:
next_id = yield self._pushers_id_gen.get_next()
yield self._simple_upsert( yield self._simple_upsert(
PushersTable.table_name, PushersTable.table_name,
dict( dict(
app_id=app_id, app_id=app_id,
pushkey=pushkey, pushkey=pushkey,
user_name=user_name,
), ),
dict( dict(
user_name=user_name, access_token=access_token,
kind=kind, kind=kind,
profile_tag=profile_tag, profile_tag=profile_tag,
app_display_name=app_display_name, app_display_name=app_display_name,
device_display_name=device_display_name, device_display_name=device_display_name,
ts=pushkey_ts, ts=pushkey_ts,
lang=lang, lang=lang,
data=data data=encode_canonical_json(data).decode("UTF-8"),
),
insertion_values=dict(
id=next_id,
), ),
desc="add_pusher", desc="add_pusher",
) )
@ -122,37 +84,38 @@ class PusherStore(SQLBaseStore):
raise StoreError(500, "Problem creating pusher.") raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
yield self._simple_delete_one( yield self._simple_delete_one(
PushersTable.table_name, PushersTable.table_name,
{"app_id": app_id, "pushkey": pushkey}, {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
desc="delete_pusher_by_app_id_pushkey", desc="delete_pusher_by_app_id_pushkey_user_name",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, last_token): def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'last_token': last_token}, {'last_token': last_token},
desc="update_pusher_last_token", desc="update_pusher_last_token",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token_and_success(self, app_id, pushkey, def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
last_token, last_success): last_token, last_success):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'last_token': last_token, 'last_success': last_success}, {'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success", desc="update_pusher_last_token_and_success",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, failing_since): def update_pusher_failing_since(self, app_id, pushkey, user_name,
failing_since):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'failing_since': failing_since}, {'failing_since': failing_since},
desc="update_pusher_failing_since", desc="update_pusher_failing_since",
) )
@ -160,21 +123,3 @@ class PusherStore(SQLBaseStore):
class PushersTable(Table): class PushersTable(Table):
table_name = "pushers" table_name = "pushers"
fields = [
"id",
"user_name",
"kind",
"profile_tag",
"app_id",
"app_display_name",
"device_display_name",
"pushkey",
"pushkey_ts",
"data",
"last_token",
"last_success",
"failing_since"
]
EntryType = collections.namedtuple("PusherEntry", fields)

View file

@ -15,8 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from sqlite3 import IntegrityError
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cached
@ -39,17 +37,13 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
row = yield self._simple_select_one( next_id = yield self._access_tokens_id_gen.get_next()
"users", {"name": user_id}, ["id"],
desc="add_access_token_to_user",
)
if not row:
raise StoreError(400, "Bad user ID supplied.")
row_id = row["id"]
yield self._simple_insert( yield self._simple_insert(
"access_tokens", "access_tokens",
{ {
"user_id": row_id, "id": next_id,
"user_id": user_id,
"token": token "token": token
}, },
desc="add_access_token_to_user", desc="add_access_token_to_user",
@ -74,32 +68,71 @@ class RegistrationStore(SQLBaseStore):
def _register(self, txn, user_id, token, password_hash): def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time()) now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn)
try: try:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) " txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
"VALUES (?,?,?)", "VALUES (?,?,?)",
[user_id, password_hash, now]) [user_id, password_hash, now])
except IntegrityError: except self.database_engine.module.IntegrityError:
raise StoreError( raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
txn.execute("INSERT INTO access_tokens(user_id, token) " + txn.execute(
"VALUES (?,?)", [txn.lastrowid, token]) "INSERT INTO access_tokens(id, user_id, token)"
" VALUES (?,?,?)",
def get_user_by_id(self, user_id): (next_id, user_id, token,)
query = ("SELECT users.name, users.password_hash FROM users"
" WHERE users.name = ?")
return self._execute(
"get_user_by_id", self.cursor_to_dict, query, user_id
) )
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash"],
allow_none=True,
)
@defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
yield self._simple_update_one('users', {
'name': user_id
}, {
'password_hash': password_hash
})
@defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id):
rows = yield self.get_user_by_id(user_id)
if len(rows) == 0:
raise Exception("No such user!")
yield self._execute(
"delete_access_tokens_apart_from", None,
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
rows[0]['id'], token_id
)
@defer.inlineCallbacks
def flush_user(self, user_id):
rows = yield self._execute(
'flush_user', None,
"SELECT token FROM access_tokens WHERE user_id = ?",
user_id
)
for r in rows:
self.get_user_by_token.invalidate(r)
@cached() @cached()
# TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or
# change whether a user is a server admin, those will need to invoke
# store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token): def get_user_by_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.
@ -134,13 +167,49 @@ class RegistrationStore(SQLBaseStore):
"SELECT users.name, users.admin," "SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id" " access_tokens.device_id, access_tokens.id as token_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id" " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"
) )
cursor = txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.cursor_to_dict(cursor) rows = self.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return rows[0]
raise StoreError(404, "Token not found.") return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
"user": user_id,
"medium": medium,
"address": address,
}, {
"validated_at": validated_at,
"added_at": added_at,
})
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
"user_threepids", {
"user": user_id
},
['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids'
)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_by_threepid(self, medium, address):
ret = yield self._simple_select_one(
"user_threepids",
{
"medium": medium,
"address": address
},
['user'], True, 'get_user_by_threepid'
)
if ret:
defer.returnValue(ret['user'])
defer.returnValue(None)

View file

@ -72,6 +72,7 @@ class RoomStore(SQLBaseStore):
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=RoomsTable.fields, retcols=RoomsTable.fields,
desc="get_room", desc="get_room",
allow_none=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -102,24 +103,37 @@ class RoomStore(SQLBaseStore):
"ON c.event_id = room_names.event_id " "ON c.event_id = room_names.event_id "
) )
# We use non printing ascii character US () as a seperator # We use non printing ascii character US (\x1F) as a separator
sql = ( sql = (
"SELECT r.room_id, n.name, t.topic, " "SELECT r.room_id, max(n.name), max(t.topic)"
"group_concat(a.room_alias, '') " " FROM rooms AS r"
"FROM rooms AS r " " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
"LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id " " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
"LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id " " WHERE r.is_public = ?"
"INNER JOIN room_aliases AS a ON a.room_id = r.room_id " " GROUP BY r.room_id"
"WHERE r.is_public = ? "
"GROUP BY r.room_id "
) % { ) % {
"topic": topic_subquery, "topic": topic_subquery,
"name": name_subquery, "name": name_subquery,
} }
c = txn.execute(sql, (is_public,)) txn.execute(sql, (is_public,))
return c.fetchall() rows = txn.fetchall()
for i, row in enumerate(rows):
room_id = row[0]
aliases = self._simple_select_onecol_txn(
txn,
table="room_aliases",
keyvalues={
"room_id": room_id
},
retcol="room_alias",
)
rows[i] = list(row) + [aliases]
return rows
rows = yield self.runInteraction( rows = yield self.runInteraction(
"get_rooms", f "get_rooms", f
@ -130,9 +144,10 @@ class RoomStore(SQLBaseStore):
"room_id": r[0], "room_id": r[0],
"name": r[1], "name": r[1],
"topic": r[2], "topic": r[2],
"aliases": r[3].split(""), "aliases": r[3],
} }
for r in rows for r in rows
if r[3] # We only return rooms that have at least one alias.
] ]
defer.returnValue(ret) defer.returnValue(ret)

View file

@ -40,7 +40,6 @@ class RoomMemberStore(SQLBaseStore):
""" """
try: try:
target_user_id = event.state_key target_user_id = event.state_key
domain = UserID.from_string(target_user_id).domain
except: except:
logger.exception( logger.exception(
"Failed to parse target_user_id=%s", target_user_id "Failed to parse target_user_id=%s", target_user_id
@ -65,42 +64,8 @@ class RoomMemberStore(SQLBaseStore):
} }
) )
# Update room hosts table
if event.membership == Membership.JOIN:
sql = (
"INSERT OR IGNORE INTO room_hosts (room_id, host) "
"VALUES (?, ?)"
)
txn.execute(sql, (event.room_id, domain))
elif event.membership != Membership.INVITE:
# Check if this was the last person to have left.
member_events = self._get_members_query_txn(
txn,
where_clause=("c.room_id = ? AND m.membership = ?"
" AND m.user_id != ?"),
where_values=(event.room_id, Membership.JOIN, target_user_id,)
)
joined_domains = set()
for e in member_events:
try:
joined_domains.add(
UserID.from_string(e.state_key).domain
)
except:
# FIXME: How do we deal with invalid user ids in the db?
logger.exception("Invalid user_id: %s", event.state_key)
if domain not in joined_domains:
sql = (
"DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
)
txn.execute(sql, (event.room_id, domain))
self.get_rooms_for_user.invalidate(target_user_id) self.get_rooms_for_user.invalidate(target_user_id)
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -110,41 +75,27 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
Deferred: Results in a MembershipEvent or None. Deferred: Results in a MembershipEvent or None.
""" """
rows = yield self._get_members_by_dict({ def f(txn):
"e.room_id": room_id, events = self._get_members_events_txn(
"m.user_id": user_id, txn,
}) room_id,
user_id=user_id,
defer.returnValue(rows[0] if rows else None)
def _get_room_member(self, txn, user_id, room_id):
sql = (
"SELECT e.* FROM events as e"
" INNER JOIN room_memberships as m"
" ON e.event_id = m.event_id"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id"
" WHERE m.user_id = ? and e.room_id = ?"
" LIMIT 1"
) )
txn.execute(sql, (user_id, room_id))
rows = self.cursor_to_dict(txn) return events[0] if events else None
if rows:
return self._parse_events_txn(txn, rows)[0] return self.runInteraction("get_room_member", f)
else:
return None
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
sql = (
"SELECT m.user_id FROM room_memberships as m" rows = self._get_members_rows_txn(
" INNER JOIN current_state_events as c" txn,
" ON m.event_id = c.event_id" room_id=room_id,
" WHERE m.membership = ? AND m.room_id = ?" membership=Membership.JOIN,
) )
txn.execute(sql, (Membership.JOIN, room_id)) return [r["user_id"] for r in rows]
return [r[0] for r in txn.fetchall()]
return self.runInteraction("get_users_in_room", f) return self.runInteraction("get_users_in_room", f)
def get_room_members(self, room_id, membership=None): def get_room_members(self, room_id, membership=None):
@ -159,11 +110,14 @@ class RoomMemberStore(SQLBaseStore):
list of namedtuples representing the members in this room. list of namedtuples representing the members in this room.
""" """
where = {"m.room_id": room_id} def f(txn):
if membership: return self._get_members_events_txn(
where["m.membership"] = membership txn,
room_id,
membership=membership,
)
return self._get_members_by_dict(where) return self.runInteraction("get_room_members", f)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
@ -209,32 +163,55 @@ class RoomMemberStore(SQLBaseStore):
] ]
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol( return self.runInteraction(
"room_hosts", "get_joined_hosts_for_room",
{"room_id": room_id}, self._get_joined_hosts_for_room_txn,
"host", room_id,
desc="get_joined_hosts_for_room",
) )
def _get_members_by_dict(self, where_dict): def _get_joined_hosts_for_room_txn(self, txn, room_id):
clause = " AND ".join("%s = ?" % k for k in where_dict.keys()) rows = self._get_members_rows_txn(
vals = where_dict.values() txn,
return self._get_members_query(clause, vals) room_id, membership=Membership.JOIN
)
joined_domains = set(
UserID.from_string(r["user_id"]).domain
for r in rows
)
return joined_domains
def _get_members_query(self, where_clause, where_values): def _get_members_query(self, where_clause, where_values):
return self.runInteraction( return self.runInteraction(
"get_members_query", self._get_members_query_txn, "get_members_query", self._get_members_events_txn,
where_clause, where_values where_clause, where_values
) )
def _get_members_query_txn(self, txn, where_clause, where_values): def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(
txn,
room_id, membership, user_id,
)
return self._get_events_txn(txn, [r["event_id"] for r in rows])
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
where_values = [room_id]
if membership:
where_clause += " AND m.membership = ?"
where_values.append(membership)
if user_id:
where_clause += " AND m.user_id = ?"
where_values.append(user_id)
sql = ( sql = (
"SELECT e.* FROM events as e " "SELECT m.* FROM room_memberships as m"
"INNER JOIN room_memberships as m " " INNER JOIN current_state_events as c"
"ON e.event_id = m.event_id " " ON m.event_id = c.event_id"
"INNER JOIN current_state_events as c " " WHERE %(where)s"
"ON m.event_id = c.event_id "
"WHERE %(where)s "
) % { ) % {
"where": where_clause, "where": where_clause,
} }
@ -242,8 +219,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, where_values) txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
results = self._parse_events_txn(txn, rows) return rows
return results
@cached() @cached()
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):

View file

@ -17,26 +17,25 @@ CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
reason TEXT NOT NULL, reason TEXT NOT NULL,
last_check TEXT NOT NULL, last_check TEXT NOT NULL,
CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE UNIQUE (event_id)
); );
-- Push notification endpoints that users have configured -- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers ( CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL, user_name TEXT NOT NULL,
profile_tag varchar(32) NOT NULL, profile_tag VARCHAR(32) NOT NULL,
kind varchar(8) NOT NULL, kind VARCHAR(8) NOT NULL,
app_id varchar(64) NOT NULL, app_id VARCHAR(64) NOT NULL,
app_display_name varchar(64) NOT NULL, app_display_name VARCHAR(64) NOT NULL,
device_display_name varchar(128) NOT NULL, device_display_name VARCHAR(128) NOT NULL,
pushkey blob NOT NULL, pushkey VARBINARY(512) NOT NULL,
ts BIGINT NOT NULL, ts BIGINT UNSIGNED NOT NULL,
lang varchar(8), lang VARCHAR(8),
data blob, data LONGBLOB,
last_token TEXT, last_token TEXT,
last_success BIGINT, last_success BIGINT UNSIGNED,
failing_since BIGINT, failing_since BIGINT UNSIGNED,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey) UNIQUE (app_id, pushkey)
); );
@ -55,13 +54,10 @@ CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
CREATE TABLE IF NOT EXISTS user_filters( CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT, user_id TEXT,
filter_id INTEGER, filter_id BIGINT UNSIGNED,
filter_json TEXT, filter_json LONGBLOB
FOREIGN KEY(user_id) REFERENCES users(id)
); );
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id user_id, filter_id
); );
PRAGMA user_version = 12;

View file

@ -19,16 +19,13 @@ CREATE TABLE IF NOT EXISTS application_services(
token TEXT, token TEXT,
hs_token TEXT, hs_token TEXT,
sender TEXT, sender TEXT,
UNIQUE(token) ON CONFLICT ROLLBACK UNIQUE(token)
); );
CREATE TABLE IF NOT EXISTS application_services_regex( CREATE TABLE IF NOT EXISTS application_services_regex(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
as_id INTEGER NOT NULL, as_id BIGINT UNSIGNED NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */ namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT, regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id) FOREIGN KEY(as_id) REFERENCES application_services(id)
); );

View file

@ -15,16 +15,17 @@
CREATE TABLE IF NOT EXISTS application_services_state( CREATE TABLE IF NOT EXISTS application_services_state(
as_id TEXT PRIMARY KEY, as_id TEXT PRIMARY KEY,
state TEXT, state VARCHAR(5),
last_txn TEXT last_txn INTEGER
); );
CREATE TABLE IF NOT EXISTS application_services_txns( CREATE TABLE IF NOT EXISTS application_services_txns(
as_id TEXT NOT NULL, as_id TEXT NOT NULL,
txn_id INTEGER NOT NULL, txn_id INTEGER NOT NULL,
event_ids TEXT NOT NULL, event_ids TEXT NOT NULL,
UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK UNIQUE(as_id, txn_id)
); );
CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns (
as_id
);

View file

@ -0,0 +1,2 @@
CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id);

View file

@ -0,0 +1,25 @@
-- Drop, copy & recreate pushers table to change unique key
-- Also add access_token column at the same time
CREATE TABLE IF NOT EXISTS pushers2 (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
access_token INTEGER DEFAULT NULL,
profile_tag varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey, user_name)
);
INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since)
SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers;
DROP TABLE pushers;
ALTER TABLE pushers2 RENAME TO pushers;

View file

@ -0,0 +1,4 @@
CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
CREATE INDEX events_order_room ON events (
room_id, topological_ordering, stream_ordering
);

View file

@ -0,0 +1,2 @@
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON remote_media_cache_thumbnails (media_id);

View file

@ -0,0 +1,9 @@
DELETE FROM event_to_state_groups WHERE state_group not in (
SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id
);
DELETE FROM event_to_state_groups WHERE rowid not in (
SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id
);

View file

@ -0,0 +1,3 @@
CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id);
CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias);

View file

@ -0,0 +1,80 @@
-- We can use SQLite features here, since other db support was only added in v16
--
DELETE FROM current_state_events WHERE rowid not in (
SELECT MIN(rowid) FROM current_state_events GROUP BY event_id
);
DROP INDEX IF EXISTS current_state_events_event_id;
CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id);
--
DELETE FROM room_memberships WHERE rowid not in (
SELECT MIN(rowid) FROM room_memberships GROUP BY event_id
);
DROP INDEX IF EXISTS room_memberships_event_id;
CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id);
--
DELETE FROM feedback WHERE rowid not in (
SELECT MIN(rowid) FROM feedback GROUP BY event_id
);
DROP INDEX IF EXISTS feedback_event_id;
CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id);
--
DELETE FROM topics WHERE rowid not in (
SELECT MIN(rowid) FROM topics GROUP BY event_id
);
DROP INDEX IF EXISTS topics_event_id;
CREATE UNIQUE INDEX topics_event_id ON topics(event_id);
--
DELETE FROM room_names WHERE rowid not in (
SELECT MIN(rowid) FROM room_names GROUP BY event_id
);
DROP INDEX IF EXISTS room_names_id;
CREATE UNIQUE INDEX room_names_id ON room_names(event_id);
--
DELETE FROM presence WHERE rowid not in (
SELECT MIN(rowid) FROM presence GROUP BY user_id
);
DROP INDEX IF EXISTS presence_id;
CREATE UNIQUE INDEX presence_id ON presence(user_id);
--
DELETE FROM presence_allow_inbound WHERE rowid not in (
SELECT MIN(rowid) FROM presence_allow_inbound
GROUP BY observed_user_id, observer_user_id
);
DROP INDEX IF EXISTS presence_allow_inbound_observers;
CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound(
observed_user_id, observer_user_id
);
--
DELETE FROM presence_list WHERE rowid not in (
SELECT MIN(rowid) FROM presence_list
GROUP BY user_id, observed_user_id
);
DROP INDEX IF EXISTS presence_list_observers;
CREATE UNIQUE INDEX presence_list_observers ON presence_list(
user_id, observed_user_id
);
--
DELETE FROM room_aliases WHERE rowid not in (
SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias
);
DROP INDEX IF EXISTS room_aliases_id;
CREATE INDEX room_aliases_id ON room_aliases(room_id);

View file

@ -0,0 +1,56 @@
-- Convert `access_tokens`.user from rowids to user strings.
-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW
CREATE TABLE IF NOT EXISTS new_access_tokens(
id BIGINT UNSIGNED PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
last_used BIGINT UNSIGNED,
UNIQUE(token)
);
INSERT INTO new_access_tokens
SELECT a.id, u.name, a.device_id, a.token, a.last_used
FROM access_tokens as a
INNER JOIN users as u ON u.id = a.user_id;
DROP TABLE access_tokens;
ALTER TABLE new_access_tokens RENAME TO access_tokens;
-- Remove ID column from `users` table
CREATE TABLE IF NOT EXISTS new_users(
name TEXT,
password_hash TEXT,
creation_ts BIGINT UNSIGNED,
admin BOOL DEFAULT 0 NOT NULL,
UNIQUE(name)
);
INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users;
DROP TABLE users;
ALTER TABLE new_users RENAME TO users;
-- Remove UNIQUE constraint from `user_ips` table
CREATE TABLE IF NOT EXISTS new_user_ips (
user_id TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
last_seen BIGINT UNSIGNED NOT NULL
);
INSERT INTO new_user_ips
SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips;
DROP TABLE user_ips;
ALTER TABLE new_user_ips RENAME TO user_ips;
CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id);
CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip);

View file

@ -16,52 +16,52 @@
CREATE TABLE IF NOT EXISTS event_forward_extremities( CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE UNIQUE (event_id, room_id)
); );
CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities( CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE UNIQUE (event_id, room_id)
); );
CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_edges( CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
is_state INTEGER NOT NULL, is_state BOOL NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state) UNIQUE (event_id, prev_event_id, room_id, is_state)
); );
CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth( CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL, min_depth INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (room_id) UNIQUE (room_id)
); );
CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); CREATE INDEX room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations( create TABLE IF NOT EXISTS event_destinations(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
destination TEXT NOT NULL, destination TEXT NOT NULL,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE UNIQUE (event_id, destination)
); );
CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); CREATE INDEX event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities( CREATE TABLE IF NOT EXISTS state_forward_extremities(
@ -69,21 +69,21 @@ CREATE TABLE IF NOT EXISTS state_forward_extremities(
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
state_key TEXT NOT NULL, state_key TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE UNIQUE (event_id, room_id)
); );
CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities( CREATE INDEX st_extrem_keys ON state_forward_extremities(
room_id, type, state_key room_id, type, state_key
); );
CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id); CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_auth( CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
auth_id TEXT NOT NULL, auth_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id) UNIQUE (event_id, auth_id, room_id)
); );
CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id); CREATE INDEX evauth_edges_id ON event_auth(event_id);
CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id); CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);

View file

@ -16,50 +16,40 @@
CREATE TABLE IF NOT EXISTS event_content_hashes ( CREATE TABLE IF NOT EXISTS event_content_hashes (
event_id TEXT, event_id TEXT,
algorithm TEXT, algorithm TEXT,
hash BLOB, hash bytea,
CONSTRAINT uniqueness UNIQUE (event_id, algorithm) UNIQUE (event_id, algorithm)
); );
CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
event_id
);
CREATE TABLE IF NOT EXISTS event_reference_hashes ( CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT, event_id TEXT,
algorithm TEXT, algorithm TEXT,
hash BLOB, hash bytea,
CONSTRAINT uniqueness UNIQUE (event_id, algorithm) UNIQUE (event_id, algorithm)
); );
CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
event_id
);
CREATE TABLE IF NOT EXISTS event_signatures ( CREATE TABLE IF NOT EXISTS event_signatures (
event_id TEXT, event_id TEXT,
signature_name TEXT, signature_name TEXT,
key_id TEXT, key_id TEXT,
signature BLOB, signature bytea,
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id) UNIQUE (event_id, signature_name, key_id)
); );
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( CREATE INDEX event_signatures_id ON event_signatures(event_id);
event_id
);
CREATE TABLE IF NOT EXISTS event_edge_hashes( CREATE TABLE IF NOT EXISTS event_edge_hashes(
event_id TEXT, event_id TEXT,
prev_event_id TEXT, prev_event_id TEXT,
algorithm TEXT, algorithm TEXT,
hash BLOB, hash bytea,
CONSTRAINT uniqueness UNIQUE ( UNIQUE (event_id, prev_event_id, algorithm)
event_id, prev_event_id, algorithm
)
); );
CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);
event_id
);

View file

@ -15,7 +15,7 @@
CREATE TABLE IF NOT EXISTS events( CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT, stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT,
topological_ordering INTEGER NOT NULL, topological_ordering BIGINT NOT NULL,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
@ -23,26 +23,24 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT, unrecognized_keys TEXT,
processed BOOL NOT NULL, processed BOOL NOT NULL,
outlier BOOL NOT NULL, outlier BOOL NOT NULL,
depth INTEGER DEFAULT 0 NOT NULL, depth BIGINT DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id) UNIQUE (event_id)
); );
CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id); CREATE INDEX events_stream_ordering ON events (stream_ordering);
CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering); CREATE INDEX events_topological_ordering ON events (topological_ordering);
CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering); CREATE INDEX events_room_id ON events (room_id);
CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
CREATE TABLE IF NOT EXISTS event_json( CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
internal_metadata NOT NULL, internal_metadata TEXT NOT NULL,
json BLOB NOT NULL, json TEXT NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id) UNIQUE (event_id)
); );
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id); CREATE INDEX event_json_room_id ON event_json(room_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events( CREATE TABLE IF NOT EXISTS state_events(
@ -50,13 +48,13 @@ CREATE TABLE IF NOT EXISTS state_events(
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
state_key TEXT NOT NULL, state_key TEXT NOT NULL,
prev_state TEXT prev_state TEXT,
UNIQUE (event_id)
); );
CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id); CREATE INDEX state_events_room_id ON state_events (room_id);
CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id); CREATE INDEX state_events_type ON state_events (type);
CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type); CREATE INDEX state_events_state_key ON state_events (state_key);
CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key);
CREATE TABLE IF NOT EXISTS current_state_events( CREATE TABLE IF NOT EXISTS current_state_events(
@ -64,13 +62,13 @@ CREATE TABLE IF NOT EXISTS current_state_events(
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
state_key TEXT NOT NULL, state_key TEXT NOT NULL,
CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE UNIQUE (room_id, type, state_key)
); );
CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id); CREATE INDEX curr_events_event_id ON current_state_events (event_id);
CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id); CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type); CREATE INDEX current_state_events_type ON current_state_events (type);
CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key); CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
CREATE TABLE IF NOT EXISTS room_memberships( CREATE TABLE IF NOT EXISTS room_memberships(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
@ -80,9 +78,9 @@ CREATE TABLE IF NOT EXISTS room_memberships(
membership TEXT NOT NULL membership TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id); CREATE INDEX room_memberships_event_id ON room_memberships (event_id);
CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id); CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id); CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
CREATE TABLE IF NOT EXISTS feedback( CREATE TABLE IF NOT EXISTS feedback(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
@ -98,8 +96,8 @@ CREATE TABLE IF NOT EXISTS topics(
topic TEXT NOT NULL topic TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id); CREATE INDEX topics_event_id ON topics(event_id);
CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id); CREATE INDEX topics_room_id ON topics(room_id);
CREATE TABLE IF NOT EXISTS room_names( CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
@ -107,19 +105,19 @@ CREATE TABLE IF NOT EXISTS room_names(
name TEXT NOT NULL name TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id); CREATE INDEX room_names_event_id ON room_names(event_id);
CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id); CREATE INDEX room_names_room_id ON room_names(room_id);
CREATE TABLE IF NOT EXISTS rooms( CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL, room_id TEXT PRIMARY KEY NOT NULL,
is_public INTEGER, is_public BOOL,
creator TEXT creator TEXT
); );
CREATE TABLE IF NOT EXISTS room_hosts( CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
host TEXT NOT NULL, host TEXT NOT NULL,
CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE UNIQUE (room_id, host)
); );
CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id); CREATE INDEX room_hosts_room_id ON room_hosts (room_id);

View file

@ -16,16 +16,16 @@ CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name. server_name TEXT, -- Server name.
fingerprint TEXT, -- Certificate fingerprint. fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from. from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms INTEGER, -- When the certifcate was added. ts_added_ms BIGINT, -- When the certifcate was added.
tls_certificate BLOB, -- DER encoded x509 certificate. tls_certificate bytea, -- DER encoded x509 certificate.
CONSTRAINT uniqueness UNIQUE (server_name, fingerprint) UNIQUE (server_name, fingerprint)
); );
CREATE TABLE IF NOT EXISTS server_signature_keys( CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name. server_name TEXT, -- Server name.
key_id TEXT, -- Key version. key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form. from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms INTEGER, -- When the key was added. ts_added_ms BIGINT, -- When the key was added.
verify_key BLOB, -- NACL verification key. verify_key bytea, -- NACL verification key.
CONSTRAINT uniqueness UNIQUE (server_name, key_id) UNIQUE (server_name, key_id)
); );

View file

@ -17,10 +17,10 @@ CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media. media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media. media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes. media_length INTEGER, -- Length of the media in bytes.
created_ts INTEGER, -- When the content was uploaded in ms. created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with. upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file. user_id TEXT, -- The user who uploaded the file.
CONSTRAINT uniqueness UNIQUE (media_id) UNIQUE (media_id)
); );
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
@ -30,23 +30,23 @@ CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
thumbnail_type TEXT, -- The MIME-type of the thumbnail. thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail. thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes. thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
CONSTRAINT uniqueness UNIQUE ( UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type media_id, thumbnail_width, thumbnail_height, thumbnail_type
) )
); );
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id CREATE INDEX local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id); ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache ( CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from. media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server. media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media. media_type TEXT, -- The MIME-type of the media.
created_ts INTEGER, -- When the content was uploaded in ms. created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with. upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes. media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk. filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (media_origin, media_id) UNIQUE (media_origin, media_id)
); );
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
@ -58,11 +58,8 @@ CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
thumbnail_type TEXT, -- The MIME-type of the thumbnail. thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes. thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk. filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE ( UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height, media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_type thumbnail_type
) )
); );
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);

View file

@ -13,26 +13,23 @@
* limitations under the License. * limitations under the License.
*/ */
CREATE TABLE IF NOT EXISTS presence( CREATE TABLE IF NOT EXISTS presence(
user_id INTEGER NOT NULL, user_id TEXT NOT NULL,
state INTEGER, state VARCHAR(20),
status_msg TEXT, status_msg TEXT,
mtime INTEGER, -- miliseconds since last state change mtime BIGINT -- miliseconds since last state change
FOREIGN KEY(user_id) REFERENCES users(id)
); );
-- For each of /my/ users which possibly-remote users are allowed to see their -- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state -- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound( CREATE TABLE IF NOT EXISTS presence_allow_inbound(
observed_user_id INTEGER NOT NULL, observed_user_id TEXT NOT NULL,
observer_user_id TEXT, -- a UserID, observer_user_id TEXT NOT NULL -- a UserID,
FOREIGN KEY(observed_user_id) REFERENCES users(id)
); );
-- For each of /my/ users (watcher), which possibly-remote users are they -- For each of /my/ users (watcher), which possibly-remote users are they
-- watching? -- watching?
CREATE TABLE IF NOT EXISTS presence_list( CREATE TABLE IF NOT EXISTS presence_list(
user_id INTEGER NOT NULL, user_id TEXT NOT NULL,
observed_user_id TEXT, -- a UserID, observed_user_id TEXT NOT NULL, -- a UserID,
accepted BOOLEAN, accepted BOOLEAN NOT NULL
FOREIGN KEY(user_id) REFERENCES users(id)
); );

View file

@ -13,8 +13,7 @@
* limitations under the License. * limitations under the License.
*/ */
CREATE TABLE IF NOT EXISTS profiles( CREATE TABLE IF NOT EXISTS profiles(
user_id INTEGER NOT NULL, user_id TEXT NOT NULL,
displayname TEXT, displayname TEXT,
avatar_url TEXT, avatar_url TEXT
FOREIGN KEY(user_id) REFERENCES users(id)
); );

View file

@ -15,8 +15,8 @@
CREATE TABLE IF NOT EXISTS redactions ( CREATE TABLE IF NOT EXISTS redactions (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
redacts TEXT NOT NULL, redacts TEXT NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id) UNIQUE (event_id)
); );
CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id); CREATE INDEX redactions_event_id ON redactions (event_id);
CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts); CREATE INDEX redactions_redacts ON redactions (redacts);

View file

@ -22,6 +22,3 @@ CREATE TABLE IF NOT EXISTS room_alias_servers(
room_alias TEXT NOT NULL, room_alias TEXT NOT NULL,
server TEXT NOT NULL server TEXT NOT NULL
); );

View file

@ -30,18 +30,11 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
CREATE TABLE IF NOT EXISTS event_to_state_groups( CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
state_group INTEGER NOT NULL, state_group INTEGER NOT NULL,
CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id) UNIQUE (event_id)
); );
CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id); CREATE INDEX state_groups_id ON state_groups(id);
CREATE INDEX IF NOT EXISTS state_groups_state_id ON state_groups_state( CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
state_group CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
); CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);
CREATE INDEX IF NOT EXISTS state_groups_state_tuple ON state_groups_state(
room_id, type, state_key
);
CREATE INDEX IF NOT EXISTS event_to_state_groups_id ON event_to_state_groups(
event_id
);

View file

@ -16,15 +16,14 @@
CREATE TABLE IF NOT EXISTS received_transactions( CREATE TABLE IF NOT EXISTS received_transactions(
transaction_id TEXT, transaction_id TEXT,
origin TEXT, origin TEXT,
ts INTEGER, ts BIGINT,
response_code INTEGER, response_code INTEGER,
response_json TEXT, response_json bytea,
has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx
CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE UNIQUE (transaction_id, origin)
); );
CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin); CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have -- Stores what transactions we've sent, what their response was (if we got one) and whether we have
@ -35,17 +34,14 @@ CREATE TABLE IF NOT EXISTS sent_transactions(
destination TEXT, destination TEXT,
response_code INTEGER DEFAULT 0, response_code INTEGER DEFAULT 0,
response_json TEXT, response_json TEXT,
ts INTEGER ts BIGINT
); );
CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination); CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions( CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
destination
);
CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id);
-- So that we can do an efficient look up of all transactions that have yet to be successfully -- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent. -- sent.
CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code); CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
-- For sent transactions only. -- For sent transactions only.
@ -56,13 +52,12 @@ CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
pdu_origin TEXT pdu_origin TEXT
); );
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination); CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
-- To track destination health -- To track destination health
CREATE TABLE IF NOT EXISTS destinations( CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY, destination TEXT PRIMARY KEY,
retry_last_ts INTEGER, retry_last_ts BIGINT,
retry_interval INTEGER retry_interval INTEGER
); );

View file

@ -16,19 +16,18 @@ CREATE TABLE IF NOT EXISTS users(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT, name TEXT,
password_hash TEXT, password_hash TEXT,
creation_ts INTEGER, creation_ts BIGINT,
admin BOOL DEFAULT 0 NOT NULL, admin SMALLINT DEFAULT 0 NOT NULL,
UNIQUE(name) ON CONFLICT ROLLBACK UNIQUE(name)
); );
CREATE TABLE IF NOT EXISTS access_tokens( CREATE TABLE IF NOT EXISTS access_tokens(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL, user_id TEXT NOT NULL,
device_id TEXT, device_id TEXT,
token TEXT NOT NULL, token TEXT NOT NULL,
last_used INTEGER, last_used BIGINT,
FOREIGN KEY(user_id) REFERENCES users(id), UNIQUE(token)
UNIQUE(token) ON CONFLICT ROLLBACK
); );
CREATE TABLE IF NOT EXISTS user_ips ( CREATE TABLE IF NOT EXISTS user_ips (
@ -37,9 +36,8 @@ CREATE TABLE IF NOT EXISTS user_ips (
device_id TEXT, device_id TEXT,
ip TEXT NOT NULL, ip TEXT NOT NULL,
user_agent TEXT NOT NULL, user_agent TEXT NOT NULL,
last_seen INTEGER NOT NULL, last_seen BIGINT NOT NULL,
CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE UNIQUE (user, access_token, ip, user_agent)
); );
CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user); CREATE INDEX user_ips_user ON user_ips(user);

View file

@ -0,0 +1,48 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services(
id BIGINT PRIMARY KEY,
url TEXT,
token TEXT,
hs_token TEXT,
sender TEXT,
UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id BIGINT PRIMARY KEY,
as_id BIGINT NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);
CREATE TABLE IF NOT EXISTS application_services_state(
as_id TEXT PRIMARY KEY,
state VARCHAR(5),
last_txn INTEGER
);
CREATE TABLE IF NOT EXISTS application_services_txns(
as_id TEXT NOT NULL,
txn_id INTEGER NOT NULL,
event_ids TEXT NOT NULL,
UNIQUE(as_id, txn_id)
);
CREATE INDEX application_services_txns_id ON application_services_txns (
as_id
);

View file

@ -0,0 +1,89 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_state BOOL NOT NULL,
UNIQUE (event_id, prev_event_id, room_id, is_state)
);
CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL,
UNIQUE (room_id)
);
CREATE INDEX room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations(
event_id TEXT NOT NULL,
destination TEXT NOT NULL,
delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
UNIQUE (event_id, destination)
);
CREATE INDEX event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL,
auth_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, auth_id, room_id)
);
CREATE INDEX evauth_edges_id ON event_auth(event_id);
CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);

View file

@ -0,0 +1,55 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_content_hashes (
event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, algorithm)
);
CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, algorithm)
);
CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_signatures (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
signature bytea,
UNIQUE (event_id, signature_name, key_id)
);
CREATE INDEX event_signatures_id ON event_signatures(event_id);
CREATE TABLE IF NOT EXISTS event_edge_hashes(
event_id TEXT,
prev_event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, prev_event_id, algorithm)
);
CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);

View file

@ -0,0 +1,128 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY,
topological_ordering BIGINT NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
content TEXT NOT NULL,
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
depth BIGINT DEFAULT 0 NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX events_stream_ordering ON events (stream_ordering);
CREATE INDEX events_topological_ordering ON events (topological_ordering);
CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
CREATE INDEX events_room_id ON events (room_id);
CREATE INDEX events_order_room ON events (
room_id, topological_ordering, stream_ordering
);
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata TEXT NOT NULL,
json TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
prev_state TEXT,
UNIQUE (event_id)
);
CREATE INDEX state_events_room_id ON state_events (room_id);
CREATE INDEX state_events_type ON state_events (type);
CREATE INDEX state_events_state_key ON state_events (state_key);
CREATE TABLE IF NOT EXISTS current_state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
UNIQUE (event_id),
UNIQUE (room_id, type, state_key)
);
CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
CREATE INDEX current_state_events_type ON current_state_events (type);
CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
CREATE TABLE IF NOT EXISTS room_memberships(
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
sender TEXT NOT NULL,
room_id TEXT NOT NULL,
membership TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
CREATE TABLE IF NOT EXISTS feedback(
event_id TEXT NOT NULL,
feedback_type TEXT,
target_event_id TEXT,
sender TEXT,
room_id TEXT,
UNIQUE (event_id)
);
CREATE TABLE IF NOT EXISTS topics(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
topic TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX topics_room_id ON topics(room_id);
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
name TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX room_names_room_id ON room_names(room_id);
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public BOOL,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,
UNIQUE (room_id, host)
);
CREATE INDEX room_hosts_room_id ON room_hosts (room_id);

View file

@ -0,0 +1,31 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name.
fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms BIGINT, -- When the certifcate was added.
tls_certificate bytea, -- DER encoded x509 certificate.
UNIQUE (server_name, fingerprint)
);
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms BIGINT, -- When the key was added.
verify_key bytea, -- NACL verification key.
UNIQUE (server_name, key_id)
);

View file

@ -0,0 +1,68 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
CREATE INDEX local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_method TEXT, -- The method used to make the thumbnail
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type
)
);
CREATE INDEX remote_media_cache_thumbnails_media_id
ON remote_media_cache_thumbnails (media_id);

View file

@ -0,0 +1,40 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS presence(
user_id TEXT NOT NULL,
state VARCHAR(20),
status_msg TEXT,
mtime BIGINT, -- miliseconds since last state change
UNIQUE (user_id)
);
-- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
observed_user_id TEXT NOT NULL,
observer_user_id TEXT NOT NULL, -- a UserID,
UNIQUE (observed_user_id, observer_user_id)
);
-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
user_id TEXT NOT NULL,
observed_user_id TEXT NOT NULL, -- a UserID,
accepted BOOLEAN NOT NULL,
UNIQUE (user_id, observed_user_id)
);
CREATE INDEX presence_list_user_id ON presence_list (user_id);

View file

@ -0,0 +1,20 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS profiles(
user_id TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
UNIQUE(user_id)
);

View file

@ -0,0 +1,73 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
UNIQUE (event_id)
);
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey bytea NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data bytea,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class SMALLINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX push_rules_user_name on push_rules (user_name);
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id BIGINT,
filter_json bytea
);
CREATE INDEX user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);
CREATE TABLE IF NOT EXISTS push_rules_enable (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
enabled SMALLINT,
UNIQUE(user_name, rule_id)
);
CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name);

View file

@ -0,0 +1,22 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS redactions (
event_id TEXT NOT NULL,
redacts TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX redactions_event_id ON redactions (event_id);
CREATE INDEX redactions_redacts ON redactions (redacts);

View file

@ -0,0 +1,29 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_aliases(
room_alias TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (room_alias)
);
CREATE INDEX room_aliases_id ON room_aliases(room_id);
CREATE TABLE IF NOT EXISTS room_alias_servers(
room_alias TEXT NOT NULL,
server TEXT NOT NULL
);
CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias);

View file

@ -0,0 +1,40 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS state_groups(
id BIGINT PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS state_groups_state(
state_group BIGINT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX state_groups_id ON state_groups(id);
CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);

View file

@ -0,0 +1,63 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores what transaction ids we have received and what our response was
CREATE TABLE IF NOT EXISTS received_transactions(
transaction_id TEXT,
origin TEXT,
ts BIGINT,
response_code INTEGER,
response_json bytea,
has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx
UNIQUE (transaction_id, origin)
);
CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
-- since referenced the transaction in another outgoing transaction
CREATE TABLE IF NOT EXISTS sent_transactions(
id BIGINT PRIMARY KEY, -- This is used to apply insertion ordering
transaction_id TEXT,
destination TEXT,
response_code INTEGER DEFAULT 0,
response_json TEXT,
ts BIGINT
);
CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
-- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent.
CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
-- For sent transactions only.
CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
transaction_id INTEGER,
destination TEXT,
pdu_id TEXT,
pdu_origin TEXT,
UNIQUE (transaction_id, destination)
);
CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
retry_last_ts BIGINT,
retry_interval INTEGER
);

View file

@ -0,0 +1,42 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS users(
name TEXT,
password_hash TEXT,
creation_ts BIGINT,
admin SMALLINT DEFAULT 0 NOT NULL,
UNIQUE(name)
);
CREATE TABLE IF NOT EXISTS access_tokens(
id BIGINT PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
last_used BIGINT,
UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS user_ips (
user_id TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
last_seen BIGINT NOT NULL
);
CREATE INDEX user_ips_user ON user_ips(user_id);
CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip);

View file

@ -14,17 +14,14 @@
*/ */
CREATE TABLE IF NOT EXISTS schema_version( CREATE TABLE IF NOT EXISTS schema_version(
Lock char(1) NOT NULL DEFAULT 'X', -- Makes sure this table only has one row. Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
version INTEGER NOT NULL, version INTEGER NOT NULL,
upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema. upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema.
CONSTRAINT schema_version_lock_x CHECK (Lock='X') CHECK (Lock='X')
CONSTRAINT schema_version_lock_uniq UNIQUE (Lock)
); );
CREATE TABLE IF NOT EXISTS applied_schema_deltas( CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL, version INTEGER NOT NULL,
file TEXT NOT NULL, file TEXT NOT NULL,
CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE UNIQUE(version, file)
); );
CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version);

View file

@ -56,7 +56,6 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": buffer(hash_bytes),
}, },
or_ignore=True,
) )
def get_event_reference_hashes(self, event_ids): def get_event_reference_hashes(self, event_ids):
@ -100,7 +99,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return dict(txn.fetchall()) return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes): hash_bytes):
@ -119,7 +118,6 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": buffer(hash_bytes),
}, },
or_ignore=True,
) )
def _get_event_signatures_txn(self, txn, event_id): def _get_event_signatures_txn(self, txn, event_id):
@ -164,7 +162,6 @@ class SignatureStore(SQLBaseStore):
"key_id": key_id, "key_id": key_id,
"signature": buffer(signature_bytes), "signature": buffer(signature_bytes),
}, },
or_ignore=True,
) )
def _get_prev_event_hashes_txn(self, txn, event_id): def _get_prev_event_hashes_txn(self, txn, event_id):
@ -198,5 +195,4 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": buffer(hash_bytes),
}, },
or_ignore=True,
) )

View file

@ -17,6 +17,8 @@ from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.util.stringutils import random_string
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,14 +93,15 @@ class StateStore(SQLBaseStore):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
state_group = self._simple_insert_txn( state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": state_group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
or_ignore=True,
) )
for state in state_events.values(): for state in state_events.values():
@ -112,7 +115,6 @@ class StateStore(SQLBaseStore):
"state_key": state.state_key, "state_key": state.state_key,
"event_id": state.event_id, "event_id": state.event_id,
}, },
or_ignore=True,
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -122,7 +124,6 @@ class StateStore(SQLBaseStore):
"state_group": state_group, "state_group": state_group,
"event_id": event.event_id, "event_id": event.event_id,
}, },
or_replace=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -154,3 +155,7 @@ class StateStore(SQLBaseStore):
events = yield self._parse_events(results) events = yield self._parse_events(results)
defer.returnValue(events) defer.returnValue(events)
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View file

@ -35,7 +35,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None: if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering") return "(%d < %s)" % (self.stream, "stream_ordering")
else: else:
return "(%d < %s OR (%d == %s AND %d < %s))" % ( return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.stream, "stream_ordering", self.stream, "stream_ordering",
@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None: if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering") return "(%d >= %s)" % (self.stream, "stream_ordering")
else: else:
return "(%d > %s OR (%d == %s AND %d >= %s))" % ( return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.stream, "stream_ordering", self.stream, "stream_ordering",
@ -240,7 +240,7 @@ class StreamStore(SQLBaseStore):
sql = ( sql = (
"SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
"(e.outlier = 0 AND (room_id IN (%(current)s)) OR " "(e.outlier = ? AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) " "(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? " "AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d " "ORDER BY stream_ordering ASC LIMIT %(limit)d "
@ -251,7 +251,7 @@ class StreamStore(SQLBaseStore):
} }
def f(txn): def f(txn):
txn.execute(sql, (user_id, user_id, from_id.stream, to_id.stream,)) txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
@ -283,7 +283,7 @@ class StreamStore(SQLBaseStore):
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
args = [room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = _StreamToken.parse(from_key).upper_bound() bounds = _StreamToken.parse(from_key).upper_bound()
@ -307,7 +307,7 @@ class StreamStore(SQLBaseStore):
sql = ( sql = (
"SELECT * FROM events" "SELECT * FROM events"
" WHERE outlier = 0 AND room_id = ? AND %(bounds)s" " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s," " ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s" " stream_ordering %(order)s %(limit)s"
) % { ) % {
@ -358,7 +358,7 @@ class StreamStore(SQLBaseStore):
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
" FROM events" " FROM events"
" WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0" " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
" ORDER BY topological_ordering DESC, stream_ordering DESC" " ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?" " LIMIT ?"
) )
@ -368,17 +368,17 @@ class StreamStore(SQLBaseStore):
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
" FROM events" " FROM events"
" WHERE room_id = ? AND stream_ordering > ?" " WHERE room_id = ? AND stream_ordering > ?"
" AND stream_ordering <= ? AND outlier = 0" " AND stream_ordering <= ? AND outlier = ?"
" ORDER BY topological_ordering DESC, stream_ordering DESC" " ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?" " LIMIT ?"
) )
def get_recent_events_for_room_txn(txn): def get_recent_events_for_room_txn(txn):
if from_token is None: if from_token is None:
txn.execute(sql, (room_id, end_token.stream, limit,)) txn.execute(sql, (room_id, end_token.stream, False, limit,))
else: else:
txn.execute(sql, ( txn.execute(sql, (
room_id, from_token.stream, end_token.stream, limit room_id, from_token.stream, end_token.stream, False, limit
)) ))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
@ -413,12 +413,10 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn "get_recent_events_for_room", get_recent_events_for_room_txn
) )
@cached(num_args=0) @defer.inlineCallbacks
def get_room_events_max_id(self): def get_room_events_max_id(self):
return self.runInteraction( token = yield self._stream_id_gen.get_max_token(self)
"get_room_events_max_id", defer.returnValue("s%d" % (token,))
self._get_room_events_max_id_txn
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_min_token(self): def _get_min_token(self):
@ -433,27 +431,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(self.min_token) defer.returnValue(self.min_token)
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
def _get_room_events_max_id_txn(self, txn):
txn.execute(
"SELECT MAX(stream_ordering) as m FROM events"
)
res = self.cursor_to_dict(txn)
logger.debug("get_room_events_max_id: %s", res)
if not res or not res[0] or not res[0]["m"]:
return "s0"
key = res[0]["m"]
return "s%d" % (key,)
@staticmethod @staticmethod
def _set_before_and_after(events, rows): def _set_before_and_after(events, rows):
for event, row in zip(events, rows): for event, row in zip(events, rows):

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table, cached from ._base import SQLBaseStore, cached
from collections import namedtuple from collections import namedtuple
@ -76,22 +76,18 @@ class TransactionStore(SQLBaseStore):
response_json (str) response_json (str)
""" """
return self.runInteraction( return self._simple_insert(
"set_received_txn_response", table=ReceivedTransactionsTable.table_name,
self._set_received_txn_response, values={
transaction_id, origin, code, response_dict "transaction_id": transaction_id,
"origin": origin,
"response_code": code,
"response_json": response_dict,
},
or_ignore=True,
desc="set_received_txn_response",
) )
def _set_received_txn_response(self, txn, transaction_id, origin, code,
response_json):
query = (
"UPDATE %s "
"SET response_code = ?, response_json = ? "
"WHERE transaction_id = ? AND origin = ?"
) % ReceivedTransactionsTable.table_name
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination, def prep_send_transaction(self, transaction_id, destination,
origin_server_ts): origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the """Persists an outgoing transaction and calculates the values for the
@ -118,41 +114,38 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination, def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts): origin_server_ts):
next_id = self._transaction_id_gen.get_next_txn(txn)
# First we find out what the prev_txns should be. # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time, # Since we know that we are only sending one transaction at a time,
# we can simply take the last one. # we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % ( query = (
SentTransactions.select_statement("destination = ?"), "SELECT * FROM sent_transactions"
" WHERE destination = ?"
" ORDER BY id DESC LIMIT 1"
) )
results = txn.execute(query, (destination,)) txn.execute(query, (destination,))
results = SentTransactions.decode_results(results) results = self.cursor_to_dict(txn)
prev_txns = [r.transaction_id for r in results] prev_txns = [r["transaction_id"] for r in results]
# Actually add the new transaction to the sent_transactions table. # Actually add the new transaction to the sent_transactions table.
query = SentTransactions.insert_statement() self._simple_insert_txn(
txn.execute(query, SentTransactions.EntryType( txn,
None, table=SentTransactions.table_name,
transaction_id=transaction_id, values={
destination=destination, "id": next_id,
ts=origin_server_ts, "transaction_id": transaction_id,
response_code=0, "destination": destination,
response_json=None "ts": origin_server_ts,
)) "response_code": 0,
"response_json": None,
}
)
# Update the tx id -> pdu id mapping # TODO Update the tx id -> pdu id mapping
# values = [
# (transaction_id, destination, pdu[0], pdu[1])
# for pdu in pdu_list
# ]
#
# logger.debug("Inserting: %s", repr(values))
#
# query = TransactionsToPduTable.insert_statement()
# txn.executemany(query, values)
return prev_txns return prev_txns
@ -171,15 +164,20 @@ class TransactionStore(SQLBaseStore):
transaction_id, destination, code, response_dict transaction_id, destination, code, response_dict
) )
def _delivered_txn(cls, txn, transaction_id, destination, def _delivered_txn(self, txn, transaction_id, destination,
code, response_json): code, response_json):
query = ( self._simple_update_one_txn(
"UPDATE %s " txn,
"SET response_code = ?, response_json = ? " table=SentTransactions.table_name,
"WHERE transaction_id = ? AND destination = ?" keyvalues={
) % SentTransactions.table_name "transaction_id": transaction_id,
"destination": destination,
txn.execute(query, (code, response_json, transaction_id, destination)) },
updatevalues={
"response_code": code,
"response_json": None, # For now, don't persist response_json
}
)
def get_transactions_after(self, transaction_id, destination): def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id. """Get all transactions after a given local transaction_id.
@ -189,25 +187,26 @@ class TransactionStore(SQLBaseStore):
destination (str) destination (str)
Returns: Returns:
list: A list of `ReceivedTransactionsTable.EntryType` list: A list of dicts
""" """
return self.runInteraction( return self.runInteraction(
"get_transactions_after", "get_transactions_after",
self._get_transactions_after, transaction_id, destination self._get_transactions_after, transaction_id, destination
) )
def _get_transactions_after(cls, txn, transaction_id, destination): def _get_transactions_after(self, txn, transaction_id, destination):
where = ( query = (
"destination = ? AND id > (select id FROM %s WHERE " "SELECT * FROM sent_transactions"
"transaction_id = ? AND destination = ?)" " WHERE destination = ? AND id >"
) % ( " ("
SentTransactions.table_name " SELECT id FROM sent_transactions"
" WHERE transaction_id = ? AND destination = ?"
" )"
) )
query = SentTransactions.select_statement(where)
txn.execute(query, (destination, transaction_id, destination)) txn.execute(query, (destination, transaction_id, destination))
return ReceivedTransactionsTable.decode_results(txn.fetchall()) return self.cursor_to_dict(txn)
@cached() @cached()
def get_destination_retry_timings(self, destination): def get_destination_retry_timings(self, destination):
@ -218,19 +217,24 @@ class TransactionStore(SQLBaseStore):
Returns: Returns:
None if not retrying None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme Otherwise a dict for the retry scheme
""" """
return self.runInteraction( return self.runInteraction(
"get_destination_retry_timings", "get_destination_retry_timings",
self._get_destination_retry_timings, destination) self._get_destination_retry_timings, destination)
def _get_destination_retry_timings(cls, txn, destination): def _get_destination_retry_timings(self, txn, destination):
query = DestinationsTable.select_statement("destination = ?") result = self._simple_select_one_txn(
txn.execute(query, (destination,)) txn,
result = txn.fetchall() table=DestinationsTable.table_name,
if result: keyvalues={
result = DestinationsTable.decode_single_result(result) "destination": destination,
if result.retry_last_ts > 0: },
retcols=DestinationsTable.fields,
allow_none=True,
)
if result and result["retry_last_ts"] > 0:
return result return result
else: else:
return None return None
@ -249,11 +253,11 @@ class TransactionStore(SQLBaseStore):
# As this is the new value, we might as well prefill the cache # As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill( self.get_destination_retry_timings.prefill(
destination, destination,
DestinationsTable.EntryType( {
destination, "destination": destination,
retry_last_ts, "retry_last_ts": retry_last_ts,
retry_interval "retry_interval": retry_interval
) },
) )
# XXX: we could chose to not bother persisting this if our cache thinks # XXX: we could chose to not bother persisting this if our cache thinks
@ -266,22 +270,38 @@ class TransactionStore(SQLBaseStore):
retry_interval, retry_interval,
) )
def _set_destination_retry_timings(cls, txn, destination, def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval): retry_last_ts, retry_interval):
query = ( query = (
"INSERT OR REPLACE INTO %s " "UPDATE destinations"
"(destination, retry_last_ts, retry_interval) " " SET retry_last_ts = ?, retry_interval = ?"
"VALUES (?, ?, ?) " " WHERE destination = ?"
) % DestinationsTable.table_name )
txn.execute(query, (destination, retry_last_ts, retry_interval)) txn.execute(
query,
(
retry_last_ts, retry_interval, destination,
)
)
if txn.rowcount == 0:
# destination wasn't already in table. Insert it.
self._simple_insert_txn(
txn,
table="destinations",
values={
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
)
def get_destinations_needing_retry(self): def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction. """Get all destinations which are due a retry for sending a transaction.
Returns: Returns:
list: A list of `DestinationsTable.EntryType` list: A list of dicts
""" """
return self.runInteraction( return self.runInteraction(
@ -289,14 +309,17 @@ class TransactionStore(SQLBaseStore):
self._get_destinations_needing_retry self._get_destinations_needing_retry
) )
def _get_destinations_needing_retry(cls, txn): def _get_destinations_needing_retry(self, txn):
where = "retry_last_ts > 0 and retry_next_ts < now()" query = (
query = DestinationsTable.select_statement(where) "SELECT * FROM destinations"
txn.execute(query) " WHERE retry_last_ts > 0 and retry_next_ts < ?"
return DestinationsTable.decode_results(txn.fetchall()) )
txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn)
class ReceivedTransactionsTable(Table): class ReceivedTransactionsTable(object):
table_name = "received_transactions" table_name = "received_transactions"
fields = [ fields = [
@ -308,10 +331,8 @@ class ReceivedTransactionsTable(Table):
"has_been_referenced", "has_been_referenced",
] ]
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
class SentTransactions(object):
class SentTransactions(Table):
table_name = "sent_transactions" table_name = "sent_transactions"
fields = [ fields = [
@ -326,7 +347,7 @@ class SentTransactions(Table):
EntryType = namedtuple("SentTransactionsEntry", fields) EntryType = namedtuple("SentTransactionsEntry", fields)
class TransactionsToPduTable(Table): class TransactionsToPduTable(object):
table_name = "transaction_id_to_pdu" table_name = "transaction_id_to_pdu"
fields = [ fields = [
@ -336,10 +357,8 @@ class TransactionsToPduTable(Table):
"pdu_origin", "pdu_origin",
] ]
EntryType = namedtuple("TransactionsToPduEntry", fields)
class DestinationsTable(object):
class DestinationsTable(Table):
table_name = "destinations" table_name = "destinations"
fields = [ fields = [
@ -347,5 +366,3 @@ class DestinationsTable(Table):
"retry_last_ts", "retry_last_ts",
"retry_interval", "retry_interval",
] ]
EntryType = namedtuple("DestinationsEntry", fields)

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from collections import deque
import contextlib
import threading
class IdGenerator(object):
def __init__(self, table, column, store):
self.table = table
self.column = column
self.store = store
self._lock = threading.Lock()
self._next_id = None
@defer.inlineCallbacks
def get_next(self):
with self._lock:
if not self._next_id:
res = yield self.store._execute_and_decode(
"IdGenerator_%s" % (self.table,),
"SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
)
self._next_id = (res and res[0] and res[0]["mx"]) or 1
i = self._next_id
self._next_id += 1
defer.returnValue(i)
def get_next_txn(self, txn):
with self._lock:
if self._next_id:
i = self._next_id
self._next_id += 1
return i
else:
txn.execute(
"SELECT MAX(%s) FROM %s" % (self.column, self.table,)
)
val, = txn.fetchone()
cur = val or 0
cur += 1
self._next_id = cur + 1
return cur
class StreamIdGenerator(object):
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
This allows us to get the "current" stream id, i.e. the stream id such that
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
Usage:
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
def __init__(self):
self._lock = threading.Lock()
self._current_max = None
self._unfinished_ids = deque()
def get_next_txn(self, txn):
"""
Usage:
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
with self._lock:
if not self._current_max:
self._compute_current_max(txn)
self._current_max += 1
next_id = self._current_max
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_id
finally:
with self._lock:
self._unfinished_ids.remove(next_id)
return manager()
@defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
defer.returnValue(self._unfinished_ids[0] - 1)
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._compute_current_max,
)
defer.returnValue(self._current_max)
def _compute_current_max(self, txn):
txn.execute("SELECT MAX(stream_ordering) FROM events")
val, = txn.fetchone()
self._current_max = int(val) if val else 1
return self._current_max

View file

@ -14,6 +14,10 @@
# limitations under the License. # limitations under the License.
from functools import wraps
import threading
class LruCache(object): class LruCache(object):
"""Least-recently-used cache.""" """Least-recently-used cache."""
# TODO(mjark) Add mutex for linked list for thread safety. # TODO(mjark) Add mutex for linked list for thread safety.
@ -24,6 +28,16 @@ class LruCache(object):
PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
lock = threading.Lock()
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return inner
def add_node(key, value): def add_node(key, value):
prev_node = list_root prev_node = list_root
next_node = prev_node[NEXT] next_node = prev_node[NEXT]
@ -51,6 +65,7 @@ class LruCache(object):
next_node[PREV] = prev_node next_node[PREV] = prev_node
cache.pop(node[KEY], None) cache.pop(node[KEY], None)
@synchronized
def cache_get(key, default=None): def cache_get(key, default=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -59,6 +74,7 @@ class LruCache(object):
else: else:
return default return default
@synchronized
def cache_set(key, value): def cache_set(key, value):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -69,6 +85,7 @@ class LruCache(object):
if len(cache) > max_size: if len(cache) > max_size:
delete_node(list_root[PREV]) delete_node(list_root[PREV])
@synchronized
def cache_set_default(key, value): def cache_set_default(key, value):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -79,6 +96,7 @@ class LruCache(object):
delete_node(list_root[PREV]) delete_node(list_root[PREV])
return value return value
@synchronized
def cache_pop(key, default=None): def cache_pop(key, default=None):
node = cache.get(key, None) node = cache.get(key, None)
if node: if node:
@ -87,9 +105,11 @@ class LruCache(object):
else: else:
return default return default
@synchronized
def cache_len(): def cache_len():
return len(cache) return len(cache)
@synchronized
def cache_contains(key): def cache_contains(key):
return key in cache return key in cache

Some files were not shown because too many files have changed in this diff Show more