Merge branch 'develop' of github.com:matrix-org/synapse into mysql

This commit is contained in:
Erik Johnston 2015-03-20 16:31:48 +00:00
commit f6583796fe
37 changed files with 542 additions and 145 deletions

View file

@ -1,3 +1,12 @@
Changes in synapse v0.8.1 (2015-03-18)
======================================
* Disable registration by default. New users can be added using the command
``register_new_matrix_user`` or by enabling registration in the config.
* Add metrics to synapse. To enable metrics use config options
``enable_metrics`` and ``metrics_port``.
* Fix bug where banning only kicked the user.
Changes in synapse v0.8.0 (2015-03-06) Changes in synapse v0.8.0 (2015-03-06)
====================================== ======================================

View file

@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate. Substituting your host and domain name as appropriate.
By default, registration of new users is disabled. You can either enable
registration in the config (it is then recommended to also set up CAPTCHA), or
you can use the command line to register new users::
$ source ~/.synapse/bin/activate
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
New user localpart: erikj
Password:
Confirm password:
Success!
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.

149
register_new_matrix_user Executable file
View file

@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import getpass
import hashlib
import hmac
import json
import sys
import urllib2
import yaml
def request_registration(user, password, server_location, shared_secret):
mac = hmac.new(
key=shared_secret,
msg=user,
digestmod=hashlib.sha1,
).hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)
try:
f = urllib2.urlopen(req)
f.read()
f.close()
print "Success."
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
def register_new_user(user, password, server_location, shared_secret):
if not user:
try:
default_user = getpass.getuser()
except:
default_user = None
if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = raw_input("New user localpart: ")
if not user:
print "Invalid user name"
sys.exit(1)
if not password:
password = getpass.getpass("Password: ")
if not password:
print "Password cannot be blank."
sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print "Passwords do not match"
sys.exit(1)
request_registration(user, password, server_location, shared_secret)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set.",
)
parser.add_argument(
"-u", "--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c", "--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)
group.add_argument(
"-k", "--shared-secret",
help="Shared secret as defined in server config file.",
)
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print "No 'registration_shared_secret' defined in config."
sys.exit(1)
else:
secret = args.shared_secret
register_new_user(args.user, args.password, args.server_url, secret)

View file

@ -45,7 +45,7 @@ setup(
version=version, version=version,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=dependencies["REQUIREMENTS"].keys(), install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
@ -55,5 +55,5 @@ setup(
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
scripts=["synctl"], scripts=["synctl", "register_new_matrix_user"],
) )

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.8.0" __version__ = "0.8.1-r2"

View file

@ -388,7 +388,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
ret = yield self.store.get_user_by_token(token=token) ret = yield self.store.get_user_by_token(token)
if not ret: if not ret:
raise StoreError(400, "Unknown token") raise StoreError(400, "Unknown token")
user_info = { user_info = {

View file

@ -60,6 +60,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service" APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object): class EventTypes(object):

View file

@ -60,9 +60,9 @@ import re
import resource import resource
import subprocess import subprocess
import sqlite3 import sqlite3
import syweb
import yaml import yaml
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,6 +84,7 @@ class SynapseHomeServer(HomeServer):
return AppServiceRestResource(self) return AppServiceRestResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
import syweb
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
@ -131,7 +132,7 @@ class SynapseHomeServer(HomeServer):
True. True.
""" """
config = self.get_config() config = self.get_config()
web_client = config.webclient web_client = config.web_client
# list containing (path_str, Resource) e.g: # list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ] # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
@ -344,7 +345,8 @@ def setup(config_options):
config.setup_logging() config.setup_logging()
check_requirements() # check any extra requirements we have now we have a config
check_requirements(config)
version_string = get_version_string() version_string = get_version_string()
@ -472,6 +474,7 @@ def run(hs):
def main(): def main():
with LoggingContext("main"): with LoggingContext("main"):
# check base requirements
check_requirements() check_requirements()
hs = setup(sys.argv[1:]) hs = setup(sys.argv[1:])
run(hs) run(hs)

View file

@ -15,19 +15,46 @@
from ._base import Config from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
import distutils.util
class RegistrationConfig(Config): class RegistrationConfig(Config):
def __init__(self, args): def __init__(self, args):
super(RegistrationConfig, self).__init__(args) super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
# `args.disable_registration` may either be a bool or a string depending
# on if the option was given a value (e.g. --disable-registration=false
# would set `args.disable_registration` to "false" not False.)
self.disable_registration = bool(
distutils.util.strtobool(str(args.disable_registration))
)
self.registration_shared_secret = args.registration_shared_secret
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser) super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration") reg_group = parser.add_argument_group("registration")
reg_group.add_argument( reg_group.add_argument(
"--disable-registration", "--disable-registration",
action='store_true', const=True,
help="Disable registration of new users." default=True,
nargs='?',
help="Disable registration of new users.",
) )
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
)
@classmethod
def generate_config(cls, args, config_dir_path):
if args.disable_registration is None:
args.disable_registration = True
if args.registration_shared_secret is None:
args.registration_shared_secret = random_string_with_symbols(50)

View file

@ -28,7 +28,7 @@ class ServerConfig(Config):
self.unsecure_port = args.unsecure_port self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file) self.pid_file = self.abspath(args.pid_file)
self.webclient = True self.web_client = args.web_client
self.manhole = args.manhole self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit self.soft_file_limit = args.soft_file_limit
@ -68,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument('--pid-file', default="homeserver.pid", server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to" help="When running as a daemon, the file to"
" store the pid in") " store the pid in")
server_group.add_argument('--web_client', default=True, type=bool,
help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"

View file

@ -361,4 +361,5 @@ SERVLET_CLASSES = (
FederationInviteServlet, FederationInviteServlet,
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet,
) )

View file

@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
""" """
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join( origin, pdu = yield self.replication_layer.make_join(
target_hosts, target_hosts,
room_id, room_id,

View file

@ -31,6 +31,7 @@ import base64
import bcrypt import bcrypt
import json import json
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart: if localpart:
if localpart and urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()

View file

@ -51,8 +51,8 @@ class RestServlet(object):
pattern = self.PATTERN pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)): if hasattr(self, "on_%s" % (method,)):
method_handler = getattr(self, "on_%s" % (method)) method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler) http_server.register_path(method, pattern, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")

View file

@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"], "syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -18,6 +17,19 @@ REQUIREMENTS = {
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
} }
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
}
}
def requirements(config=None, include_conditional=False):
reqs = REQUIREMENTS.copy()
for key, req in CONDITIONAL_REQUIREMENTS.items():
if (config and getattr(config, key)) or include_conditional:
reqs.update(req)
return reqs
def github_link(project, version, egg): def github_link(project, version, egg):
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
pass pass
def check_requirements(): def check_requirements(config=None):
"""Checks that all the modules needed by synapse have been correctly """Checks that all the modules needed by synapse have been correctly
installed and are at the correct version""" installed and are at the correct version"""
for dependency, module_requirements in REQUIREMENTS.items(): for dependency, module_requirements in (
requirements(config, include_conditional=False).items()):
for module_requirement in module_requirements: for module_requirement in module_requirements:
if ">=" in module_requirement: if ">=" in module_requirement:
module_name, required_version = module_requirement.split(">=") module_name, required_version = module_requirement.split(">=")
@ -110,7 +123,7 @@ def list_requirements():
egg = link.split("#egg=")[1] egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0]) linked.append(egg.split('-')[0])
result.append(link) result.append(link)
for requirement in REQUIREMENTS: for requirement in requirements(include_conditional=True):
is_linked = False is_linked = False
for link in linked: for link in linked:
if requirement.replace('-', '_').startswith(link): if requirement.replace('-', '_').startswith(link):

View file

@ -27,7 +27,6 @@ from hashlib import sha1
import hmac import hmac
import simplejson as json import simplejson as json
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"] login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE is_application_server = login_type == LoginType.APPLICATION_SERVICE
if self.disable_registration and not is_application_server: is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
not self.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled") raise SynapseError(403, "Registration has been disabled")
stages = { stages = {
LoginType.RECAPTCHA: self._do_recaptcha, LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password, LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity, LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service LoginType.APPLICATION_SERVICE: self._do_app_service,
LoginType.SHARED_SECRET: self._do_shared_secret,
} }
session_info = self._get_session_info(request, session) session_info = self._get_session_info(request, session)
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
) )
password = register_json["password"].encode("utf-8") password = register_json["password"].encode("utf-8")
desired_user_id = (register_json["user"].encode("utf-8") desired_user_id = (
if "user" in register_json else None) register_json["user"].encode("utf-8")
if (desired_user_id if "user" in register_json else None
and urllib.quote(desired_user_id) != desired_user_id): )
raise SynapseError(
400,
"User ID must only contain characters which do not " +
"require URL encoding.")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
(user_id, token) = yield handler.register( (user_id, token) = yield handler.register(
localpart=desired_user_id, localpart=desired_user_id,
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
}) })
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
if not isinstance(register_json.get("mac", None), basestring):
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), basestring):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), basestring):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(register_json["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
password = register_json["password"].encode("utf-8")
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
else:
raise SynapseError(
403, "HMAC incorrect",
)
def _parse_json(request): def _parse_json(request):
try: try:

View file

@ -91,7 +91,7 @@ class DataStore(RoomMemberStore, RoomStore,
"user_agent": user_agent, "user_agent": user_agent,
"last_seen": int(self._clock.time_msec()), "last_seen": int(self._clock.time_msec()),
}, },
or_replace=True, desc="insert_client_ip",
) )
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
@ -101,6 +101,7 @@ class DataStore(RoomMemberStore, RoomStore,
retcols=[ retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen" "device_id", "access_token", "ip", "user_agent", "last_seen"
], ],
desc="get_user_ip_and_agents",
) )

View file

@ -25,6 +25,7 @@ import synapse.metrics
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools
import simplejson as json import simplejson as json
import sys import sys
import time import time
@ -53,13 +54,12 @@ cache_counter = metrics.register_cache(
# TODO(paul): # TODO(paul):
# * more generic key management
# * consider other eviction strategies - LRU? # * consider other eviction strategies - LRU?
def cached(max_entries=1000): def cached(max_entries=1000, num_args=1):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
The function is presumed to take one additional argument, which is used as The function is presumed to take zero or more arguments, which are used in
the key for the cache. Cache hits are served directly from the cache; a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value. misses use the function body to generate the value.
The wrapped function has an additional member, a callable called The wrapped function has an additional member, a callable called
@ -75,25 +75,41 @@ def cached(max_entries=1000):
caches_by_name[name] = cache caches_by_name[name] = cache
def prefill(key, value): def prefill(*args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != num_args:
raise ValueError("Expected a call to have %d arguments", num_args)
while len(cache) > max_entries: while len(cache) > max_entries:
cache.popitem(last=False) cache.popitem(last=False)
cache[key] = value cache[keyargs] = value
@functools.wraps(orig)
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, key): def wrapped(self, *keyargs):
if key in cache: if len(keyargs) != num_args:
raise ValueError("Expected a call to have %d arguments", num_args)
if keyargs in cache:
cache_counter.inc_hits(name) cache_counter.inc_hits(name)
defer.returnValue(cache[key]) defer.returnValue(cache[keyargs])
cache_counter.inc_misses(name) cache_counter.inc_misses(name)
ret = yield orig(self, key) ret = yield orig(self, *keyargs)
prefill(key, ret)
prefill_args = keyargs + (ret,)
prefill(*prefill_args)
defer.returnValue(ret) defer.returnValue(ret)
def invalidate(key): def invalidate(*keyargs):
cache.pop(key, None) if len(keyargs) != num_args:
raise ValueError("Expected a call to have %d arguments", num_args)
cache.pop(keyargs, None)
wrapped.invalidate = invalidate wrapped.invalidate = invalidate
wrapped.prefill = prefill wrapped.prefill = prefill
@ -325,7 +341,8 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False, or_ignore=False): def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
desc="_simple_insert"):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
@ -334,7 +351,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self.runInteraction( return self.runInteraction(
"_simple_insert", desc,
self._simple_insert_txn, table, values, or_replace=or_replace, self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore, or_ignore=or_ignore,
) )
@ -357,7 +374,7 @@ class SQLBaseStore(object):
txn.execute(sql, values.values()) txn.execute(sql, values.values())
return txn.lastrowid return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values): def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
""" """
Args: Args:
table (str): The table to upsert into table (str): The table to upsert into
@ -366,7 +383,7 @@ class SQLBaseStore(object):
Returns: A deferred Returns: A deferred
""" """
return self.runInteraction( return self.runInteraction(
"_simple_upsert", desc,
self._simple_upsert_txn, table, keyvalues, values self._simple_upsert_txn, table, keyvalues, values
) )
@ -402,7 +419,7 @@ class SQLBaseStore(object):
txn.execute(sql, allvalues.values()) txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False): allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning a single column from it.
@ -414,12 +431,15 @@ class SQLBaseStore(object):
allow_none : If true, return None instead of failing if the SELECT allow_none : If true, return None instead of failing if the SELECT
statement returns no rows statement returns no rows
""" """
return self._simple_selectupdate_one( return self.runInteraction(
table, keyvalues, retcols=retcols, allow_none=allow_none desc,
self._simple_select_one_txn,
table, keyvalues, retcols, allow_none,
) )
def _simple_select_one_onecol(self, table, keyvalues, retcol, def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False): allow_none=False,
desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it." return a single row, returning a single column from it."
@ -429,7 +449,7 @@ class SQLBaseStore(object):
retcol : string giving the name of the column to return retcol : string giving the name of the column to return
""" """
return self.runInteraction( return self.runInteraction(
"_simple_select_one_onecol", desc,
self._simple_select_one_onecol_txn, self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none, table, keyvalues, retcol, allow_none=allow_none,
) )
@ -464,7 +484,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn.fetchall()]
def _simple_select_onecol(self, table, keyvalues, retcol): def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
@ -477,12 +498,13 @@ class SQLBaseStore(object):
Deferred: Results in a list Deferred: Results in a list
""" """
return self.runInteraction( return self.runInteraction(
"_simple_select_onecol", desc,
self._simple_select_onecol_txn, self._simple_select_onecol_txn,
table, keyvalues, retcol table, keyvalues, retcol
) )
def _simple_select_list(self, table, keyvalues, retcols): def _simple_select_list(self, table, keyvalues, retcols,
desc="_simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -493,7 +515,7 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
return self.runInteraction( return self.runInteraction(
"_simple_select_list", desc,
self._simple_select_list_txn, self._simple_select_list_txn,
table, keyvalues, retcols table, keyvalues, retcols
) )
@ -525,7 +547,7 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None): desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values. columns in a row matching the key values.
@ -543,29 +565,37 @@ class SQLBaseStore(object):
get-and-set. This can be used to implement compare-and-set by putting get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well. the update column in the 'keyvalues' dict as well.
""" """
return self._simple_selectupdate_one(table, keyvalues, updatevalues, return self.runInteraction(
retcols=retcols) desc,
self._simple_update_one_txn,
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, table, keyvalues, updatevalues,
retcols=None, allow_none=False):
""" Combined SELECT then UPDATE."""
if retcols:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
) )
if updatevalues: def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % ( update_sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in updatevalues), ", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k,) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
def func(txn): txn.execute(
ret = None update_sql,
if retcols: updatevalues.values() + keyvalues.values()
)
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
txn.execute(select_sql, keyvalues.values()) txn.execute(select_sql, keyvalues.values())
row = txn.fetchone() row = txn.fetchone()
@ -576,12 +606,29 @@ class SQLBaseStore(object):
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
ret = dict(zip(retcols, row)) return dict(zip(retcols, row))
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False,
desc="_simple_selectupdate_one"):
""" Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
ret = self._simple_select_one_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=retcols,
allow_none=allow_none,
)
if updatevalues: if updatevalues:
txn.execute( self._simple_update_one_txn(
update_sql, txn,
updatevalues.values() + keyvalues.values() table=table,
keyvalues=keyvalues,
updatevalues=updatevalues,
) )
# if txn.rowcount == 0: # if txn.rowcount == 0:
@ -590,9 +637,9 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
return ret return ret
return self.runInteraction("_simple_selectupdate_one", func) return self.runInteraction(desc, func)
def _simple_delete_one(self, table, keyvalues): def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -611,9 +658,9 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self.runInteraction("_simple_delete_one", func) return self.runInteraction(desc, func)
def _simple_delete(self, table, keyvalues): def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table. """Executes a DELETE query on the named table.
Args: Args:
@ -621,7 +668,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
""" """
return self.runInteraction("_simple_delete", self._simple_delete_txn) return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues): def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (

View file

@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
{"room_alias": room_alias.to_string()}, {"room_alias": room_alias.to_string()},
"room_id", "room_id",
allow_none=True, allow_none=True,
desc="get_association_from_room_alias",
) )
if not room_id: if not room_id:
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias_servers", "room_alias_servers",
{"room_alias": room_alias.to_string()}, {"room_alias": room_alias.to_string()},
"server", "server",
desc="get_association_from_room_alias",
) )
if not servers: if not servers:
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
}, },
desc="create_room_alias_association",
) )
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
raise SynapseError( raise SynapseError(
@ -100,7 +103,8 @@ class DirectoryStore(SQLBaseStore):
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"server": server, "server": server,
} },
desc="create_room_alias_association",
) )
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -139,4 +143,5 @@ class DirectoryStore(SQLBaseStore):
"room_aliases", "room_aliases",
{"room_id": room_id}, {"room_id": room_id},
"room_alias", "room_alias",
desc="get_aliases_for_room",
) )

View file

@ -426,3 +426,15 @@ class EventFederationStore(SQLBaseStore):
) )
return events[:limit] return events[:limit]
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
self._clean_room_for_join_txn,
room_id,
)
def _clean_room_for_join_txn(self, txn, room_id):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))

View file

@ -52,6 +52,7 @@ class EventsStore(SQLBaseStore):
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
self.get_room_events_max_id.invalidate()
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
@ -242,7 +243,6 @@ class EventsStore(SQLBaseStore):
if stream_ordering is None: if stream_ordering is None:
stream_ordering = self.get_next_stream_id() stream_ordering = self.get_next_stream_id()
unrec = { unrec = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event.get_dict().items()

View file

@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
}, },
retcol="filter_json", retcol="filter_json",
allow_none=False, allow_none=False,
desc="get_user_filter",
) )
defer.returnValue(json.loads(def_json)) defer.returnValue(json.loads(def_json))

View file

@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_id": media_id}, {"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"), ("media_type", "media_length", "upload_name", "created_ts"),
allow_none=True, allow_none=True,
desc="get_local_media",
) )
def store_local_media(self, media_id, media_type, time_now_ms, upload_name, def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name, "upload_name": upload_name,
"media_length": media_length, "media_length": media_length,
"user_id": user_id.to_string(), "user_id": user_id.to_string(),
} },
desc="store_local_media",
) )
def get_local_media_thumbnails(self, media_id): def get_local_media_thumbnails(self, media_id):
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
( (
"thumbnail_width", "thumbnail_height", "thumbnail_method", "thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "thumbnail_type", "thumbnail_length",
) ),
desc="get_local_media_thumbnails",
) )
def store_local_thumbnail(self, media_id, thumbnail_width, def store_local_thumbnail(self, media_id, thumbnail_width,
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_method": thumbnail_method, "thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type, "thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length, "thumbnail_length": thumbnail_length,
} },
desc="store_local_thumbnail",
) )
def get_cached_remote_media(self, origin, media_id): def get_cached_remote_media(self, origin, media_id):
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
"filesystem_id", "filesystem_id",
), ),
allow_none=True, allow_none=True,
desc="get_cached_remote_media",
) )
def store_cached_remote_media(self, origin, media_id, media_type, def store_cached_remote_media(self, origin, media_id, media_type,
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms, "created_ts": time_now_ms,
"upload_name": upload_name, "upload_name": upload_name,
"filesystem_id": filesystem_id, "filesystem_id": filesystem_id,
} },
desc="store_cached_remote_media",
) )
def get_remote_media_thumbnails(self, origin, media_id): def get_remote_media_thumbnails(self, origin, media_id):
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
( (
"thumbnail_width", "thumbnail_height", "thumbnail_method", "thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id", "thumbnail_type", "thumbnail_length", "filesystem_id",
) ),
desc="get_remote_media_thumbnails",
) )
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id, def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_type": thumbnail_type, "thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length, "thumbnail_length": thumbnail_length,
"filesystem_id": filesystem_id, "filesystem_id": filesystem_id,
} },
desc="store_remote_media_thumbnail",
) )

View file

@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
return self._simple_insert( return self._simple_insert(
table="presence", table="presence",
values={"user_id": user_localpart}, values={"user_id": user_localpart},
desc="create_presence",
) )
def has_presence_state(self, user_localpart): def has_presence_state(self, user_localpart):
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcols=["user_id"], retcols=["user_id"],
allow_none=True, allow_none=True,
desc="has_presence_state",
) )
def get_presence_state(self, user_localpart): def get_presence_state(self, user_localpart):
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
table="presence", table="presence",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"], retcols=["state", "status_msg", "mtime"],
desc="get_presence_state",
) )
def set_presence_state(self, user_localpart, new_state): def set_presence_state(self, user_localpart, new_state):
@ -45,6 +48,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"], "status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()}, "mtime": self._clock.time_msec()},
desc="set_presence_state",
) )
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
@ -52,6 +56,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound", table="presence_allow_inbound",
values={"observed_user_id": observed_localpart, values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid}, "observer_user_id": observer_userid},
desc="allow_presence_visible",
) )
def disallow_presence_visible(self, observed_localpart, observer_userid): def disallow_presence_visible(self, observed_localpart, observer_userid):
@ -59,6 +64,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound", table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart, keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid}, "observer_user_id": observer_userid},
desc="disallow_presence_visible",
) )
def is_presence_visible(self, observed_localpart, observer_userid): def is_presence_visible(self, observed_localpart, observer_userid):
@ -68,6 +74,7 @@ class PresenceStore(SQLBaseStore):
"observer_user_id": observer_userid}, "observer_user_id": observer_userid},
retcols=["observed_user_id"], retcols=["observed_user_id"],
allow_none=True, allow_none=True,
desc="is_presence_visible",
) )
def add_presence_list_pending(self, observer_localpart, observed_userid): def add_presence_list_pending(self, observer_localpart, observed_userid):
@ -76,6 +83,7 @@ class PresenceStore(SQLBaseStore):
values={"user_id": observer_localpart, values={"user_id": observer_localpart,
"observed_user_id": observed_userid, "observed_user_id": observed_userid,
"accepted": False}, "accepted": False},
desc="add_presence_list_pending",
) )
def set_presence_list_accepted(self, observer_localpart, observed_userid): def set_presence_list_accepted(self, observer_localpart, observed_userid):
@ -84,6 +92,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": observer_localpart, keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted",
) )
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -95,6 +104,7 @@ class PresenceStore(SQLBaseStore):
table="presence_list", table="presence_list",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"], retcols=["observed_user_id", "accepted"],
desc="get_presence_list",
) )
def del_presence_list(self, observer_localpart, observed_userid): def del_presence_list(self, observer_localpart, observed_userid):
@ -102,4 +112,5 @@ class PresenceStore(SQLBaseStore):
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list",
) )

View file

@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
return self._simple_insert( return self._simple_insert(
table="profiles", table="profiles",
values={"user_id": user_localpart}, values={"user_id": user_localpart},
desc="create_profile",
) )
def get_profile_displayname(self, user_localpart): def get_profile_displayname(self, user_localpart):
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="displayname", retcol="displayname",
desc="get_profile_displayname",
) )
def set_profile_displayname(self, user_localpart, new_displayname): def set_profile_displayname(self, user_localpart, new_displayname):
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname}, updatevalues={"displayname": new_displayname},
desc="set_profile_displayname",
) )
def get_profile_avatar_url(self, user_localpart): def get_profile_avatar_url(self, user_localpart):
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="avatar_url", retcol="avatar_url",
desc="get_profile_avatar_url",
) )
def set_profile_avatar_url(self, user_localpart, new_avatar_url): def set_profile_avatar_url(self, user_localpart, new_avatar_url):
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url}, updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
) )

View file

@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
results = yield self._simple_select_list( results = yield self._simple_select_list(
PushRuleEnableTable.table_name, PushRuleEnableTable.table_name,
{'user_name': user_name}, {'user_name': user_name},
PushRuleEnableTable.fields PushRuleEnableTable.fields,
desc="get_push_rules_enabled_for_user",
) )
defer.returnValue( defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results} {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
""" """
yield self._simple_delete_one( yield self._simple_delete_one(
PushRuleTable.table_name, PushRuleTable.table_name,
{'user_name': user_name, 'rule_id': rule_id} {'user_name': user_name, 'rule_id': rule_id},
desc="delete_push_rule",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
yield self._simple_upsert( yield self._simple_upsert(
PushRuleEnableTable.table_name, PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_name, 'rule_id': rule_id},
{'enabled': enabled} {'enabled': enabled},
desc="set_push_rule_enabled",
) )

View file

@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
ts=pushkey_ts, ts=pushkey_ts,
lang=lang, lang=lang,
data=data data=data
)) ),
desc="add_pusher",
)
except Exception as e: except Exception as e:
logger.error("create_pusher with failed: %s", e) logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.") raise StoreError(500, "Problem creating pusher.")
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
yield self._simple_delete_one( yield self._simple_delete_one(
PushersTable.table_name, PushersTable.table_name,
dict(app_id=app_id, pushkey=pushkey) {"app_id": app_id, "pushkey": pushkey},
desc="delete_pusher_by_app_id_pushkey",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey},
{'last_token': last_token} {'last_token': last_token},
desc="update_pusher_last_token",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey},
{'last_token': last_token, 'last_success': last_success} {'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey},
{'failing_since': failing_since} {'failing_since': failing_since},
desc="update_pusher_failing_since",
) )

View file

@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -44,7 +44,8 @@ class RegistrationStore(SQLBaseStore):
{ {
"user_id": user_id, "user_id": user_id,
"token": token "token": token
} },
desc="add_access_token_to_user",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -87,6 +88,11 @@ class RegistrationStore(SQLBaseStore):
"get_user_by_id", self.cursor_to_dict, query, user_id "get_user_by_id", self.cursor_to_dict, query, user_id
) )
@cached()
# TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or
# change whether a user is a server admin, those will need to invoke
# store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token): def get_user_by_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.
@ -111,6 +117,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={"name": user.to_string()}, keyvalues={"name": user.to_string()},
retcol="admin", retcol="admin",
allow_none=True, allow_none=True,
desc="is_server_admin",
) )
defer.returnValue(res if res else False) defer.returnValue(res if res else False)

View file

@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"reason": reason, "reason": reason,
"last_check": self._clock.time_msec(), "last_check": self._clock.time_msec(),
} },
) )
def get_rejection_reason(self, event_id): def get_rejection_reason(self, event_id):
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
}, },
allow_none=True, allow_none=True,
desc="get_rejection_reason",
) )

View file

@ -15,11 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from sqlite3 import IntegrityError
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore
import collections import collections
import logging import logging
@ -27,8 +25,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple("OpsLevel", ( OpsLevel = collections.namedtuple(
"ban_level", "kick_level", "redact_level") "OpsLevel",
("ban_level", "kick_level", "redact_level",)
) )
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
StoreError if the room could not be stored. StoreError if the room could not be stored.
""" """
try: try:
yield self._simple_insert(RoomsTable.table_name, dict( yield self._simple_insert(
room_id=room_id, RoomsTable.table_name,
creator=room_creator_user_id, {
is_public=is_public "room_id": room_id,
)) "creator": room_creator_user_id,
except IntegrityError: "is_public": is_public,
raise StoreError(409, "Room ID in use.") },
desc="store_room",
)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -66,12 +67,11 @@ class RoomStore(SQLBaseStore):
Returns: Returns:
A namedtuple containing the room information, or an empty list. A namedtuple containing the room information, or an empty list.
""" """
query = RoomsTable.select_statement("room_id=?") return self._simple_select_one(
return self._execute( table=RoomsTable.table_name,
"get_room", keyvalues={"room_id": room_id},
lambda txn: RoomsTable.decode_single_result(txn.fetchall()), retcols=RoomsTable.fields,
query, desc="get_room",
room_id,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -146,7 +146,7 @@ class RoomStore(SQLBaseStore):
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"topic": event.content["topic"], "topic": event.content["topic"],
} },
) )
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
@ -199,7 +199,7 @@ class RoomStore(SQLBaseStore):
defer.returnValue((name, aliases)) defer.returnValue((name, aliases))
class RoomsTable(Table): class RoomsTable(object):
table_name = "rooms" table_name = "rooms"
fields = [ fields = [
@ -207,5 +207,3 @@ class RoomsTable(Table):
"is_public", "is_public",
"creator" "creator"
] ]
EntryType = collections.namedtuple("RoomEntry", fields)

View file

@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
return self._simple_select_onecol( return self._simple_select_onecol(
"room_hosts", "room_hosts",
{"room_id": room_id}, {"room_id": room_id},
"host" "host",
desc="get_joined_hosts_for_room",
) )
def _get_members_by_dict(self, where_dict): def _get_members_by_dict(self, where_dict):

View file

@ -160,3 +160,4 @@ class StateStore(SQLBaseStore):
def _make_group_id(clock): def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5) return str(int(clock.time_msec())) + random_string(5)

View file

@ -35,7 +35,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -413,6 +413,7 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn "get_recent_events_for_room", get_recent_events_for_room_txn
) )
@cached(num_args=0)
def get_room_events_max_id(self): def get_room_events_max_id(self):
return self.runInteraction( return self.runInteraction(
"get_room_events_max_id", "get_room_events_max_id",

View file

@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
) )
def _get_received_txn_response(self, txn, transaction_id, origin): def _get_received_txn_response(self, txn, transaction_id, origin):
where_clause = "transaction_id = ? AND origin = ?" result = self._simple_select_one_txn(
query = ReceivedTransactionsTable.select_statement(where_clause) txn,
table=ReceivedTransactionsTable.table_name,
keyvalues={
"transaction_id": transaction_id,
"origin": origin,
},
retcols=ReceivedTransactionsTable.fields,
allow_none=True,
)
txn.execute(query, (transaction_id, origin)) if result and result.response_code:
return result["response_code"], result["response_json"]
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
if results and results[0].response_code:
return (results[0].response_code, results[0].response_json)
else: else:
return None return None

View file

@ -16,6 +16,10 @@
import random import random
import string import string
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
def origin_from_ucid(ucid): def origin_from_ucid(ucid):
return ucid.split("@", 1)[1] return ucid.split("@", 1)[1]
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
def random_string(length): def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
def random_string_with_symbols(length):
return ''.join(
random.choice(_string_with_symbols) for _ in xrange(length)
)

View file

@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",) self.mock_txn.fetchone.return_value = ("Old Value",)
ret = yield self.datastore._simple_update_one( ret = yield self.datastore._simple_selectupdate_one(
table="tablename", table="tablename",
keyvalues={"keycol": "TheKey"}, keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"}, updatevalues={"columname": "New Value"},

View file

@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_room(self): def test_get_room(self):
self.assertObjectHasAttributes( self.assertDictContainsSubset(
{"room_id": self.room.to_string(), {"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True}, "is_public": True},