mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-21 12:14:29 +03:00
Merge branch 'develop' of github.com:matrix-org/synapse into mysql
This commit is contained in:
commit
f6583796fe
37 changed files with 542 additions and 145 deletions
|
@ -1,3 +1,12 @@
|
||||||
|
Changes in synapse v0.8.1 (2015-03-18)
|
||||||
|
======================================
|
||||||
|
|
||||||
|
* Disable registration by default. New users can be added using the command
|
||||||
|
``register_new_matrix_user`` or by enabling registration in the config.
|
||||||
|
* Add metrics to synapse. To enable metrics use config options
|
||||||
|
``enable_metrics`` and ``metrics_port``.
|
||||||
|
* Fix bug where banning only kicked the user.
|
||||||
|
|
||||||
Changes in synapse v0.8.0 (2015-03-06)
|
Changes in synapse v0.8.0 (2015-03-06)
|
||||||
======================================
|
======================================
|
||||||
|
|
||||||
|
|
11
README.rst
11
README.rst
|
@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
|
||||||
|
|
||||||
Substituting your host and domain name as appropriate.
|
Substituting your host and domain name as appropriate.
|
||||||
|
|
||||||
|
By default, registration of new users is disabled. You can either enable
|
||||||
|
registration in the config (it is then recommended to also set up CAPTCHA), or
|
||||||
|
you can use the command line to register new users::
|
||||||
|
|
||||||
|
$ source ~/.synapse/bin/activate
|
||||||
|
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
|
||||||
|
New user localpart: erikj
|
||||||
|
Password:
|
||||||
|
Confirm password:
|
||||||
|
Success!
|
||||||
|
|
||||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||||
a TURN server. See docs/turn-howto.rst for details.
|
a TURN server. See docs/turn-howto.rst for details.
|
||||||
|
|
||||||
|
|
149
register_new_matrix_user
Executable file
149
register_new_matrix_user
Executable file
|
@ -0,0 +1,149 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import getpass
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib2
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def request_registration(user, password, server_location, shared_secret):
|
||||||
|
mac = hmac.new(
|
||||||
|
key=shared_secret,
|
||||||
|
msg=user,
|
||||||
|
digestmod=hashlib.sha1,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"user": user,
|
||||||
|
"password": password,
|
||||||
|
"mac": mac,
|
||||||
|
"type": "org.matrix.login.shared_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
server_location = server_location.rstrip("/")
|
||||||
|
|
||||||
|
print "Sending registration request..."
|
||||||
|
|
||||||
|
req = urllib2.Request(
|
||||||
|
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||||
|
data=json.dumps(data),
|
||||||
|
headers={'Content-Type': 'application/json'}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
f = urllib2.urlopen(req)
|
||||||
|
f.read()
|
||||||
|
f.close()
|
||||||
|
print "Success."
|
||||||
|
except urllib2.HTTPError as e:
|
||||||
|
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||||
|
if 400 <= e.code < 500:
|
||||||
|
if e.info().type == "application/json":
|
||||||
|
resp = json.load(e)
|
||||||
|
if "error" in resp:
|
||||||
|
print resp["error"]
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def register_new_user(user, password, server_location, shared_secret):
|
||||||
|
if not user:
|
||||||
|
try:
|
||||||
|
default_user = getpass.getuser()
|
||||||
|
except:
|
||||||
|
default_user = None
|
||||||
|
|
||||||
|
if default_user:
|
||||||
|
user = raw_input("New user localpart [%s]: " % (default_user,))
|
||||||
|
if not user:
|
||||||
|
user = default_user
|
||||||
|
else:
|
||||||
|
user = raw_input("New user localpart: ")
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
print "Invalid user name"
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
password = getpass.getpass("Password: ")
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
print "Password cannot be blank."
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
confirm_password = getpass.getpass("Confirm password: ")
|
||||||
|
|
||||||
|
if password != confirm_password:
|
||||||
|
print "Passwords do not match"
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
request_registration(user, password, server_location, shared_secret)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Used to register new users with a given home server when"
|
||||||
|
" registration has been disabled. The home server must be"
|
||||||
|
" configured with the 'registration_shared_secret' option"
|
||||||
|
" set.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-u", "--user",
|
||||||
|
default=None,
|
||||||
|
help="Local part of the new user. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--password",
|
||||||
|
default=None,
|
||||||
|
help="New password for user. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
group.add_argument(
|
||||||
|
"-c", "--config",
|
||||||
|
type=argparse.FileType('r'),
|
||||||
|
help="Path to server config file. Used to read in shared secret.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"-k", "--shared-secret",
|
||||||
|
help="Shared secret as defined in server config file.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"server_url",
|
||||||
|
default="https://localhost:8448",
|
||||||
|
nargs='?',
|
||||||
|
help="URL to use to talk to the home server. Defaults to "
|
||||||
|
" 'https://localhost:8448'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if "config" in args and args.config:
|
||||||
|
config = yaml.safe_load(args.config)
|
||||||
|
secret = config.get("registration_shared_secret", None)
|
||||||
|
if not secret:
|
||||||
|
print "No 'registration_shared_secret' defined in config."
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
secret = args.shared_secret
|
||||||
|
|
||||||
|
register_new_user(args.user, args.password, args.server_url, secret)
|
4
setup.py
4
setup.py
|
@ -45,7 +45,7 @@ setup(
|
||||||
version=version,
|
version=version,
|
||||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||||
description="Reference Synapse Home Server",
|
description="Reference Synapse Home Server",
|
||||||
install_requires=dependencies["REQUIREMENTS"].keys(),
|
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||||
setup_requires=[
|
setup_requires=[
|
||||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||||
"setuptools_trial",
|
"setuptools_trial",
|
||||||
|
@ -55,5 +55,5 @@ setup(
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
scripts=["synctl"],
|
scripts=["synctl", "register_new_matrix_user"],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.8.0"
|
__version__ = "0.8.1-r2"
|
||||||
|
|
|
@ -388,7 +388,7 @@ class Auth(object):
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ret = yield self.store.get_user_by_token(token=token)
|
ret = yield self.store.get_user_by_token(token)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise StoreError(400, "Unknown token")
|
raise StoreError(400, "Unknown token")
|
||||||
user_info = {
|
user_info = {
|
||||||
|
|
|
@ -60,6 +60,7 @@ class LoginType(object):
|
||||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||||
RECAPTCHA = u"m.login.recaptcha"
|
RECAPTCHA = u"m.login.recaptcha"
|
||||||
APPLICATION_SERVICE = u"m.login.application_service"
|
APPLICATION_SERVICE = u"m.login.application_service"
|
||||||
|
SHARED_SECRET = u"org.matrix.login.shared_secret"
|
||||||
|
|
||||||
|
|
||||||
class EventTypes(object):
|
class EventTypes(object):
|
||||||
|
|
|
@ -60,9 +60,9 @@ import re
|
||||||
import resource
|
import resource
|
||||||
import subprocess
|
import subprocess
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import syweb
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,6 +84,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
return AppServiceRestResource(self)
|
return AppServiceRestResource(self)
|
||||||
|
|
||||||
def build_resource_for_web_client(self):
|
def build_resource_for_web_client(self):
|
||||||
|
import syweb
|
||||||
syweb_path = os.path.dirname(syweb.__file__)
|
syweb_path = os.path.dirname(syweb.__file__)
|
||||||
webclient_path = os.path.join(syweb_path, "webclient")
|
webclient_path = os.path.join(syweb_path, "webclient")
|
||||||
return File(webclient_path) # TODO configurable?
|
return File(webclient_path) # TODO configurable?
|
||||||
|
@ -131,7 +132,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
True.
|
True.
|
||||||
"""
|
"""
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
web_client = config.webclient
|
web_client = config.web_client
|
||||||
|
|
||||||
# list containing (path_str, Resource) e.g:
|
# list containing (path_str, Resource) e.g:
|
||||||
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
||||||
|
@ -344,7 +345,8 @@ def setup(config_options):
|
||||||
|
|
||||||
config.setup_logging()
|
config.setup_logging()
|
||||||
|
|
||||||
check_requirements()
|
# check any extra requirements we have now we have a config
|
||||||
|
check_requirements(config)
|
||||||
|
|
||||||
version_string = get_version_string()
|
version_string = get_version_string()
|
||||||
|
|
||||||
|
@ -472,6 +474,7 @@ def run(hs):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
with LoggingContext("main"):
|
with LoggingContext("main"):
|
||||||
|
# check base requirements
|
||||||
check_requirements()
|
check_requirements()
|
||||||
hs = setup(sys.argv[1:])
|
hs = setup(sys.argv[1:])
|
||||||
run(hs)
|
run(hs)
|
||||||
|
|
|
@ -15,19 +15,46 @@
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
|
||||||
|
from synapse.util.stringutils import random_string_with_symbols
|
||||||
|
|
||||||
|
import distutils.util
|
||||||
|
|
||||||
|
|
||||||
class RegistrationConfig(Config):
|
class RegistrationConfig(Config):
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super(RegistrationConfig, self).__init__(args)
|
super(RegistrationConfig, self).__init__(args)
|
||||||
self.disable_registration = args.disable_registration
|
|
||||||
|
# `args.disable_registration` may either be a bool or a string depending
|
||||||
|
# on if the option was given a value (e.g. --disable-registration=false
|
||||||
|
# would set `args.disable_registration` to "false" not False.)
|
||||||
|
self.disable_registration = bool(
|
||||||
|
distutils.util.strtobool(str(args.disable_registration))
|
||||||
|
)
|
||||||
|
self.registration_shared_secret = args.registration_shared_secret
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_arguments(cls, parser):
|
def add_arguments(cls, parser):
|
||||||
super(RegistrationConfig, cls).add_arguments(parser)
|
super(RegistrationConfig, cls).add_arguments(parser)
|
||||||
reg_group = parser.add_argument_group("registration")
|
reg_group = parser.add_argument_group("registration")
|
||||||
|
|
||||||
reg_group.add_argument(
|
reg_group.add_argument(
|
||||||
"--disable-registration",
|
"--disable-registration",
|
||||||
action='store_true',
|
const=True,
|
||||||
help="Disable registration of new users."
|
default=True,
|
||||||
|
nargs='?',
|
||||||
|
help="Disable registration of new users.",
|
||||||
)
|
)
|
||||||
|
reg_group.add_argument(
|
||||||
|
"--registration-shared-secret", type=str,
|
||||||
|
help="If set, allows registration by anyone who also has the shared"
|
||||||
|
" secret, even if registration is otherwise disabled.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_config(cls, args, config_dir_path):
|
||||||
|
if args.disable_registration is None:
|
||||||
|
args.disable_registration = True
|
||||||
|
|
||||||
|
if args.registration_shared_secret is None:
|
||||||
|
args.registration_shared_secret = random_string_with_symbols(50)
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ServerConfig(Config):
|
||||||
self.unsecure_port = args.unsecure_port
|
self.unsecure_port = args.unsecure_port
|
||||||
self.daemonize = args.daemonize
|
self.daemonize = args.daemonize
|
||||||
self.pid_file = self.abspath(args.pid_file)
|
self.pid_file = self.abspath(args.pid_file)
|
||||||
self.webclient = True
|
self.web_client = args.web_client
|
||||||
self.manhole = args.manhole
|
self.manhole = args.manhole
|
||||||
self.soft_file_limit = args.soft_file_limit
|
self.soft_file_limit = args.soft_file_limit
|
||||||
|
|
||||||
|
@ -68,6 +68,8 @@ class ServerConfig(Config):
|
||||||
server_group.add_argument('--pid-file', default="homeserver.pid",
|
server_group.add_argument('--pid-file', default="homeserver.pid",
|
||||||
help="When running as a daemon, the file to"
|
help="When running as a daemon, the file to"
|
||||||
" store the pid in")
|
" store the pid in")
|
||||||
|
server_group.add_argument('--web_client', default=True, type=bool,
|
||||||
|
help="Whether or not to serve a web client")
|
||||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||||
type=int,
|
type=int,
|
||||||
help="Turn on the twisted telnet manhole"
|
help="Turn on the twisted telnet manhole"
|
||||||
|
|
|
@ -361,4 +361,5 @@ SERVLET_CLASSES = (
|
||||||
FederationInviteServlet,
|
FederationInviteServlet,
|
||||||
FederationQueryAuthServlet,
|
FederationQueryAuthServlet,
|
||||||
FederationGetMissingEventsServlet,
|
FederationGetMissingEventsServlet,
|
||||||
|
FederationEventAuthServlet,
|
||||||
)
|
)
|
||||||
|
|
|
@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
logger.debug("Joining %s to %s", joinee, room_id)
|
logger.debug("Joining %s to %s", joinee, room_id)
|
||||||
|
|
||||||
|
yield self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
origin, pdu = yield self.replication_layer.make_join(
|
origin, pdu = yield self.replication_layer.make_join(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
|
|
|
@ -31,6 +31,7 @@ import base64
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||||
|
|
||||||
if localpart:
|
if localpart:
|
||||||
|
if localpart and urllib.quote(localpart) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID must only contain characters which do not"
|
||||||
|
" require URL encoding."
|
||||||
|
)
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
|
|
@ -51,8 +51,8 @@ class RestServlet(object):
|
||||||
pattern = self.PATTERN
|
pattern = self.PATTERN
|
||||||
|
|
||||||
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
||||||
if hasattr(self, "on_%s" % (method)):
|
if hasattr(self, "on_%s" % (method,)):
|
||||||
method_handler = getattr(self, "on_%s" % (method))
|
method_handler = getattr(self, "on_%s" % (method,))
|
||||||
http_server.register_path(method, pattern, method_handler)
|
http_server.register_path(method, pattern, method_handler)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("RestServlet must register something.")
|
raise NotImplementedError("RestServlet must register something.")
|
||||||
|
|
|
@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
REQUIREMENTS = {
|
REQUIREMENTS = {
|
||||||
"syutil>=0.0.3": ["syutil"],
|
"syutil>=0.0.3": ["syutil"],
|
||||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
|
||||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
"Twisted==14.0.2": ["twisted==14.0.2"],
|
||||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||||
|
@ -18,6 +17,19 @@ REQUIREMENTS = {
|
||||||
"pillow": ["PIL"],
|
"pillow": ["PIL"],
|
||||||
"pydenticon": ["pydenticon"],
|
"pydenticon": ["pydenticon"],
|
||||||
}
|
}
|
||||||
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
|
"web_client": {
|
||||||
|
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def requirements(config=None, include_conditional=False):
|
||||||
|
reqs = REQUIREMENTS.copy()
|
||||||
|
for key, req in CONDITIONAL_REQUIREMENTS.items():
|
||||||
|
if (config and getattr(config, key)) or include_conditional:
|
||||||
|
reqs.update(req)
|
||||||
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
def github_link(project, version, egg):
|
def github_link(project, version, egg):
|
||||||
|
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def check_requirements():
|
def check_requirements(config=None):
|
||||||
"""Checks that all the modules needed by synapse have been correctly
|
"""Checks that all the modules needed by synapse have been correctly
|
||||||
installed and are at the correct version"""
|
installed and are at the correct version"""
|
||||||
for dependency, module_requirements in REQUIREMENTS.items():
|
for dependency, module_requirements in (
|
||||||
|
requirements(config, include_conditional=False).items()):
|
||||||
for module_requirement in module_requirements:
|
for module_requirement in module_requirements:
|
||||||
if ">=" in module_requirement:
|
if ">=" in module_requirement:
|
||||||
module_name, required_version = module_requirement.split(">=")
|
module_name, required_version = module_requirement.split(">=")
|
||||||
|
@ -110,7 +123,7 @@ def list_requirements():
|
||||||
egg = link.split("#egg=")[1]
|
egg = link.split("#egg=")[1]
|
||||||
linked.append(egg.split('-')[0])
|
linked.append(egg.split('-')[0])
|
||||||
result.append(link)
|
result.append(link)
|
||||||
for requirement in REQUIREMENTS:
|
for requirement in requirements(include_conditional=True):
|
||||||
is_linked = False
|
is_linked = False
|
||||||
for link in linked:
|
for link in linked:
|
||||||
if requirement.replace('-', '_').startswith(link):
|
if requirement.replace('-', '_').startswith(link):
|
||||||
|
|
|
@ -27,7 +27,6 @@ from hashlib import sha1
|
||||||
import hmac
|
import hmac
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
login_type = register_json["type"]
|
login_type = register_json["type"]
|
||||||
|
|
||||||
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
||||||
if self.disable_registration and not is_application_server:
|
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
||||||
|
|
||||||
|
can_register = (
|
||||||
|
not self.disable_registration
|
||||||
|
or is_application_server
|
||||||
|
or is_using_shared_secret
|
||||||
|
)
|
||||||
|
if not can_register:
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
|
||||||
stages = {
|
stages = {
|
||||||
LoginType.RECAPTCHA: self._do_recaptcha,
|
LoginType.RECAPTCHA: self._do_recaptcha,
|
||||||
LoginType.PASSWORD: self._do_password,
|
LoginType.PASSWORD: self._do_password,
|
||||||
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
||||||
LoginType.APPLICATION_SERVICE: self._do_app_service
|
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
||||||
|
LoginType.SHARED_SECRET: self._do_shared_secret,
|
||||||
}
|
}
|
||||||
|
|
||||||
session_info = self._get_session_info(request, session)
|
session_info = self._get_session_info(request, session)
|
||||||
|
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
password = register_json["password"].encode("utf-8")
|
password = register_json["password"].encode("utf-8")
|
||||||
desired_user_id = (register_json["user"].encode("utf-8")
|
desired_user_id = (
|
||||||
if "user" in register_json else None)
|
register_json["user"].encode("utf-8")
|
||||||
if (desired_user_id
|
if "user" in register_json else None
|
||||||
and urllib.quote(desired_user_id) != desired_user_id):
|
)
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"User ID must only contain characters which do not " +
|
|
||||||
"require URL encoding.")
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
(user_id, token) = yield handler.register(
|
(user_id, token) = yield handler.register(
|
||||||
localpart=desired_user_id,
|
localpart=desired_user_id,
|
||||||
|
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_shared_secret(self, request, register_json, session):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if not isinstance(register_json.get("mac", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected mac.")
|
||||||
|
if not isinstance(register_json.get("user", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected 'user' key.")
|
||||||
|
if not isinstance(register_json.get("password", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected 'password' key.")
|
||||||
|
|
||||||
|
if not self.hs.config.registration_shared_secret:
|
||||||
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
|
user = register_json["user"].encode("utf-8")
|
||||||
|
|
||||||
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
|
# have the buffer interface
|
||||||
|
got_mac = str(register_json["mac"])
|
||||||
|
|
||||||
|
want_mac = hmac.new(
|
||||||
|
key=self.hs.config.registration_shared_secret,
|
||||||
|
msg=user,
|
||||||
|
digestmod=sha1,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
password = register_json["password"].encode("utf-8")
|
||||||
|
|
||||||
|
if compare_digest(want_mac, got_mac):
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
user_id, token = yield handler.register(
|
||||||
|
localpart=user,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
self._remove_session(session)
|
||||||
|
defer.returnValue({
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
403, "HMAC incorrect",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_json(request):
|
def _parse_json(request):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -91,7 +91,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
"user_agent": user_agent,
|
"user_agent": user_agent,
|
||||||
"last_seen": int(self._clock.time_msec()),
|
"last_seen": int(self._clock.time_msec()),
|
||||||
},
|
},
|
||||||
or_replace=True,
|
desc="insert_client_ip",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_user_ip_and_agents(self, user):
|
def get_user_ip_and_agents(self, user):
|
||||||
|
@ -101,6 +101,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
retcols=[
|
retcols=[
|
||||||
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
||||||
],
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import synapse.metrics
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
import functools
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
@ -53,13 +54,12 @@ cache_counter = metrics.register_cache(
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul):
|
# TODO(paul):
|
||||||
# * more generic key management
|
|
||||||
# * consider other eviction strategies - LRU?
|
# * consider other eviction strategies - LRU?
|
||||||
def cached(max_entries=1000):
|
def cached(max_entries=1000, num_args=1):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
The function is presumed to take one additional argument, which is used as
|
The function is presumed to take zero or more arguments, which are used in
|
||||||
the key for the cache. Cache hits are served directly from the cache;
|
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||||
misses use the function body to generate the value.
|
misses use the function body to generate the value.
|
||||||
|
|
||||||
The wrapped function has an additional member, a callable called
|
The wrapped function has an additional member, a callable called
|
||||||
|
@ -75,25 +75,41 @@ def cached(max_entries=1000):
|
||||||
|
|
||||||
caches_by_name[name] = cache
|
caches_by_name[name] = cache
|
||||||
|
|
||||||
def prefill(key, value):
|
def prefill(*args): # because I can't *keyargs, value
|
||||||
|
keyargs = args[:-1]
|
||||||
|
value = args[-1]
|
||||||
|
|
||||||
|
if len(keyargs) != num_args:
|
||||||
|
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||||
|
|
||||||
while len(cache) > max_entries:
|
while len(cache) > max_entries:
|
||||||
cache.popitem(last=False)
|
cache.popitem(last=False)
|
||||||
|
|
||||||
cache[key] = value
|
cache[keyargs] = value
|
||||||
|
|
||||||
|
@functools.wraps(orig)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped(self, key):
|
def wrapped(self, *keyargs):
|
||||||
if key in cache:
|
if len(keyargs) != num_args:
|
||||||
|
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||||
|
|
||||||
|
if keyargs in cache:
|
||||||
cache_counter.inc_hits(name)
|
cache_counter.inc_hits(name)
|
||||||
defer.returnValue(cache[key])
|
defer.returnValue(cache[keyargs])
|
||||||
|
|
||||||
cache_counter.inc_misses(name)
|
cache_counter.inc_misses(name)
|
||||||
ret = yield orig(self, key)
|
ret = yield orig(self, *keyargs)
|
||||||
prefill(key, ret)
|
|
||||||
|
prefill_args = keyargs + (ret,)
|
||||||
|
prefill(*prefill_args)
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def invalidate(key):
|
def invalidate(*keyargs):
|
||||||
cache.pop(key, None)
|
if len(keyargs) != num_args:
|
||||||
|
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||||
|
|
||||||
|
cache.pop(keyargs, None)
|
||||||
|
|
||||||
wrapped.invalidate = invalidate
|
wrapped.invalidate = invalidate
|
||||||
wrapped.prefill = prefill
|
wrapped.prefill = prefill
|
||||||
|
@ -325,7 +341,8 @@ class SQLBaseStore(object):
|
||||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||||
# no complex WHERE clauses, just a dict of values for columns.
|
# no complex WHERE clauses, just a dict of values for columns.
|
||||||
|
|
||||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
|
||||||
|
desc="_simple_insert"):
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -334,7 +351,7 @@ class SQLBaseStore(object):
|
||||||
or_replace : bool; if True performs an INSERT OR REPLACE
|
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_insert",
|
desc,
|
||||||
self._simple_insert_txn, table, values, or_replace=or_replace,
|
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||||
or_ignore=or_ignore,
|
or_ignore=or_ignore,
|
||||||
)
|
)
|
||||||
|
@ -357,7 +374,7 @@ class SQLBaseStore(object):
|
||||||
txn.execute(sql, values.values())
|
txn.execute(sql, values.values())
|
||||||
return txn.lastrowid
|
return txn.lastrowid
|
||||||
|
|
||||||
def _simple_upsert(self, table, keyvalues, values):
|
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
table (str): The table to upsert into
|
table (str): The table to upsert into
|
||||||
|
@ -366,7 +383,7 @@ class SQLBaseStore(object):
|
||||||
Returns: A deferred
|
Returns: A deferred
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_upsert",
|
desc,
|
||||||
self._simple_upsert_txn, table, keyvalues, values
|
self._simple_upsert_txn, table, keyvalues, values
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -402,7 +419,7 @@ class SQLBaseStore(object):
|
||||||
txn.execute(sql, allvalues.values())
|
txn.execute(sql, allvalues.values())
|
||||||
|
|
||||||
def _simple_select_one(self, table, keyvalues, retcols,
|
def _simple_select_one(self, table, keyvalues, retcols,
|
||||||
allow_none=False):
|
allow_none=False, desc="_simple_select_one"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning a single column from it.
|
return a single row, returning a single column from it.
|
||||||
|
|
||||||
|
@ -414,12 +431,15 @@ class SQLBaseStore(object):
|
||||||
allow_none : If true, return None instead of failing if the SELECT
|
allow_none : If true, return None instead of failing if the SELECT
|
||||||
statement returns no rows
|
statement returns no rows
|
||||||
"""
|
"""
|
||||||
return self._simple_selectupdate_one(
|
return self.runInteraction(
|
||||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
desc,
|
||||||
|
self._simple_select_one_txn,
|
||||||
|
table, keyvalues, retcols, allow_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||||
allow_none=False):
|
allow_none=False,
|
||||||
|
desc="_simple_select_one_onecol"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning a single column from it."
|
return a single row, returning a single column from it."
|
||||||
|
|
||||||
|
@ -429,7 +449,7 @@ class SQLBaseStore(object):
|
||||||
retcol : string giving the name of the column to return
|
retcol : string giving the name of the column to return
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_one_onecol",
|
desc,
|
||||||
self._simple_select_one_onecol_txn,
|
self._simple_select_one_onecol_txn,
|
||||||
table, keyvalues, retcol, allow_none=allow_none,
|
table, keyvalues, retcol, allow_none=allow_none,
|
||||||
)
|
)
|
||||||
|
@ -464,7 +484,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
return [r[0] for r in txn.fetchall()]
|
return [r[0] for r in txn.fetchall()]
|
||||||
|
|
||||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||||
|
desc="_simple_select_onecol"):
|
||||||
"""Executes a SELECT query on the named table, which returns a list
|
"""Executes a SELECT query on the named table, which returns a list
|
||||||
comprising of the values of the named column from the selected rows.
|
comprising of the values of the named column from the selected rows.
|
||||||
|
|
||||||
|
@ -477,12 +498,13 @@ class SQLBaseStore(object):
|
||||||
Deferred: Results in a list
|
Deferred: Results in a list
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_onecol",
|
desc,
|
||||||
self._simple_select_onecol_txn,
|
self._simple_select_onecol_txn,
|
||||||
table, keyvalues, retcol
|
table, keyvalues, retcol
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_list(self, table, keyvalues, retcols):
|
def _simple_select_list(self, table, keyvalues, retcols,
|
||||||
|
desc="_simple_select_list"):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
|
@ -493,7 +515,7 @@ class SQLBaseStore(object):
|
||||||
retcols : list of strings giving the names of the columns to return
|
retcols : list of strings giving the names of the columns to return
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_list",
|
desc,
|
||||||
self._simple_select_list_txn,
|
self._simple_select_list_txn,
|
||||||
table, keyvalues, retcols
|
table, keyvalues, retcols
|
||||||
)
|
)
|
||||||
|
@ -525,7 +547,7 @@ class SQLBaseStore(object):
|
||||||
return self.cursor_to_dict(txn)
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
retcols=None):
|
desc="_simple_update_one"):
|
||||||
"""Executes an UPDATE query on the named table, setting new values for
|
"""Executes an UPDATE query on the named table, setting new values for
|
||||||
columns in a row matching the key values.
|
columns in a row matching the key values.
|
||||||
|
|
||||||
|
@ -543,29 +565,37 @@ class SQLBaseStore(object):
|
||||||
get-and-set. This can be used to implement compare-and-set by putting
|
get-and-set. This can be used to implement compare-and-set by putting
|
||||||
the update column in the 'keyvalues' dict as well.
|
the update column in the 'keyvalues' dict as well.
|
||||||
"""
|
"""
|
||||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
return self.runInteraction(
|
||||||
retcols=retcols)
|
desc,
|
||||||
|
self._simple_update_one_txn,
|
||||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
table, keyvalues, updatevalues,
|
||||||
retcols=None, allow_none=False):
|
|
||||||
""" Combined SELECT then UPDATE."""
|
|
||||||
if retcols:
|
|
||||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
|
||||||
", ".join(retcols),
|
|
||||||
table,
|
|
||||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if updatevalues:
|
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||||
)
|
)
|
||||||
|
|
||||||
def func(txn):
|
txn.execute(
|
||||||
ret = None
|
update_sql,
|
||||||
if retcols:
|
updatevalues.values() + keyvalues.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
if txn.rowcount == 0:
|
||||||
|
raise StoreError(404, "No row found")
|
||||||
|
if txn.rowcount > 1:
|
||||||
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
|
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||||
|
allow_none=False):
|
||||||
|
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||||
|
", ".join(retcols),
|
||||||
|
table,
|
||||||
|
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(select_sql, keyvalues.values())
|
txn.execute(select_sql, keyvalues.values())
|
||||||
|
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
|
@ -576,12 +606,29 @@ class SQLBaseStore(object):
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
ret = dict(zip(retcols, row))
|
return dict(zip(retcols, row))
|
||||||
|
|
||||||
|
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||||
|
retcols=None, allow_none=False,
|
||||||
|
desc="_simple_selectupdate_one"):
|
||||||
|
""" Combined SELECT then UPDATE."""
|
||||||
|
def func(txn):
|
||||||
|
ret = None
|
||||||
|
if retcols:
|
||||||
|
ret = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table=table,
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=retcols,
|
||||||
|
allow_none=allow_none,
|
||||||
|
)
|
||||||
|
|
||||||
if updatevalues:
|
if updatevalues:
|
||||||
txn.execute(
|
self._simple_update_one_txn(
|
||||||
update_sql,
|
txn,
|
||||||
updatevalues.values() + keyvalues.values()
|
table=table,
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
updatevalues=updatevalues,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if txn.rowcount == 0:
|
# if txn.rowcount == 0:
|
||||||
|
@ -590,9 +637,9 @@ class SQLBaseStore(object):
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
return self.runInteraction("_simple_selectupdate_one", func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete_one(self, table, keyvalues):
|
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
|
||||||
"""Executes a DELETE query on the named table, expecting to delete a
|
"""Executes a DELETE query on the named table, expecting to delete a
|
||||||
single row.
|
single row.
|
||||||
|
|
||||||
|
@ -611,9 +658,9 @@ class SQLBaseStore(object):
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "more than one row matched")
|
raise StoreError(500, "more than one row matched")
|
||||||
return self.runInteraction("_simple_delete_one", func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete(self, table, keyvalues):
|
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
|
||||||
"""Executes a DELETE query on the named table.
|
"""Executes a DELETE query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -621,7 +668,7 @@ class SQLBaseStore(object):
|
||||||
keyvalues : dict of column names and values to select the row with
|
keyvalues : dict of column names and values to select the row with
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
return self.runInteraction(desc, self._simple_delete_txn)
|
||||||
|
|
||||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||||
sql = "DELETE FROM %s WHERE %s" % (
|
sql = "DELETE FROM %s WHERE %s" % (
|
||||||
|
|
|
@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"room_id",
|
"room_id",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_association_from_room_alias",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not room_id:
|
if not room_id:
|
||||||
|
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
"room_alias_servers",
|
"room_alias_servers",
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"server",
|
"server",
|
||||||
|
desc="get_association_from_room_alias",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
|
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
},
|
},
|
||||||
|
desc="create_room_alias_association",
|
||||||
)
|
)
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -100,7 +103,8 @@ class DirectoryStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
"server": server,
|
"server": server,
|
||||||
}
|
},
|
||||||
|
desc="create_room_alias_association",
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_room_alias(self, room_alias):
|
def delete_room_alias(self, room_alias):
|
||||||
|
@ -139,4 +143,5 @@ class DirectoryStore(SQLBaseStore):
|
||||||
"room_aliases",
|
"room_aliases",
|
||||||
{"room_id": room_id},
|
{"room_id": room_id},
|
||||||
"room_alias",
|
"room_alias",
|
||||||
|
desc="get_aliases_for_room",
|
||||||
)
|
)
|
||||||
|
|
|
@ -426,3 +426,15 @@ class EventFederationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
return events[:limit]
|
return events[:limit]
|
||||||
|
|
||||||
|
def clean_room_for_join(self, room_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"clean_room_for_join",
|
||||||
|
self._clean_room_for_join_txn,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_room_for_join_txn(self, txn, room_id):
|
||||||
|
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||||
|
|
||||||
|
txn.execute(query, (room_id,))
|
||||||
|
|
|
@ -52,6 +52,7 @@ class EventsStore(SQLBaseStore):
|
||||||
is_new_state=is_new_state,
|
is_new_state=is_new_state,
|
||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
)
|
)
|
||||||
|
self.get_room_events_max_id.invalidate()
|
||||||
except _RollbackButIsFineException:
|
except _RollbackButIsFineException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -242,7 +243,6 @@ class EventsStore(SQLBaseStore):
|
||||||
if stream_ordering is None:
|
if stream_ordering is None:
|
||||||
stream_ordering = self.get_next_stream_id()
|
stream_ordering = self.get_next_stream_id()
|
||||||
|
|
||||||
|
|
||||||
unrec = {
|
unrec = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in event.get_dict().items()
|
for k, v in event.get_dict().items()
|
||||||
|
|
|
@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
retcol="filter_json",
|
retcol="filter_json",
|
||||||
allow_none=False,
|
allow_none=False,
|
||||||
|
desc="get_user_filter",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(json.loads(def_json))
|
defer.returnValue(json.loads(def_json))
|
||||||
|
|
|
@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
{"media_id": media_id},
|
{"media_id": media_id},
|
||||||
("media_type", "media_length", "upload_name", "created_ts"),
|
("media_type", "media_length", "upload_name", "created_ts"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||||
|
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"media_length": media_length,
|
"media_length": media_length,
|
||||||
"user_id": user_id.to_string(),
|
"user_id": user_id.to_string(),
|
||||||
}
|
},
|
||||||
|
desc="store_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_local_media_thumbnails(self, media_id):
|
def get_local_media_thumbnails(self, media_id):
|
||||||
|
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
(
|
(
|
||||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||||
"thumbnail_type", "thumbnail_length",
|
"thumbnail_type", "thumbnail_length",
|
||||||
)
|
),
|
||||||
|
desc="get_local_media_thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||||
|
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"thumbnail_method": thumbnail_method,
|
"thumbnail_method": thumbnail_method,
|
||||||
"thumbnail_type": thumbnail_type,
|
"thumbnail_type": thumbnail_type,
|
||||||
"thumbnail_length": thumbnail_length,
|
"thumbnail_length": thumbnail_length,
|
||||||
}
|
},
|
||||||
|
desc="store_local_thumbnail",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cached_remote_media(self, origin, media_id):
|
def get_cached_remote_media(self, origin, media_id):
|
||||||
|
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"filesystem_id",
|
"filesystem_id",
|
||||||
),
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||||
|
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"created_ts": time_now_ms,
|
"created_ts": time_now_ms,
|
||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"filesystem_id": filesystem_id,
|
"filesystem_id": filesystem_id,
|
||||||
}
|
},
|
||||||
|
desc="store_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_remote_media_thumbnails(self, origin, media_id):
|
def get_remote_media_thumbnails(self, origin, media_id):
|
||||||
|
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
(
|
(
|
||||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||||
)
|
),
|
||||||
|
desc="get_remote_media_thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||||
|
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"thumbnail_type": thumbnail_type,
|
"thumbnail_type": thumbnail_type,
|
||||||
"thumbnail_length": thumbnail_length,
|
"thumbnail_length": thumbnail_length,
|
||||||
"filesystem_id": filesystem_id,
|
"filesystem_id": filesystem_id,
|
||||||
}
|
},
|
||||||
|
desc="store_remote_media_thumbnail",
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
table="presence",
|
table="presence",
|
||||||
values={"user_id": user_localpart},
|
values={"user_id": user_localpart},
|
||||||
|
desc="create_presence",
|
||||||
)
|
)
|
||||||
|
|
||||||
def has_presence_state(self, user_localpart):
|
def has_presence_state(self, user_localpart):
|
||||||
|
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcols=["user_id"],
|
retcols=["user_id"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="has_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_presence_state(self, user_localpart):
|
def get_presence_state(self, user_localpart):
|
||||||
|
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence",
|
table="presence",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcols=["state", "status_msg", "mtime"],
|
retcols=["state", "status_msg", "mtime"],
|
||||||
|
desc="get_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_presence_state(self, user_localpart, new_state):
|
def set_presence_state(self, user_localpart, new_state):
|
||||||
|
@ -45,6 +48,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
updatevalues={"state": new_state["state"],
|
updatevalues={"state": new_state["state"],
|
||||||
"status_msg": new_state["status_msg"],
|
"status_msg": new_state["status_msg"],
|
||||||
"mtime": self._clock.time_msec()},
|
"mtime": self._clock.time_msec()},
|
||||||
|
desc="set_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -52,6 +56,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_allow_inbound",
|
table="presence_allow_inbound",
|
||||||
values={"observed_user_id": observed_localpart,
|
values={"observed_user_id": observed_localpart,
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
|
desc="allow_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -59,6 +64,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_allow_inbound",
|
table="presence_allow_inbound",
|
||||||
keyvalues={"observed_user_id": observed_localpart,
|
keyvalues={"observed_user_id": observed_localpart,
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
|
desc="disallow_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_presence_visible(self, observed_localpart, observer_userid):
|
def is_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -68,6 +74,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
retcols=["observed_user_id"],
|
retcols=["observed_user_id"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="is_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||||
|
@ -76,6 +83,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
values={"user_id": observer_localpart,
|
values={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid,
|
"observed_user_id": observed_userid,
|
||||||
"accepted": False},
|
"accepted": False},
|
||||||
|
desc="add_presence_list_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||||
|
@ -84,6 +92,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
updatevalues={"accepted": True},
|
updatevalues={"accepted": True},
|
||||||
|
desc="set_presence_list_accepted",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_presence_list(self, observer_localpart, accepted=None):
|
def get_presence_list(self, observer_localpart, accepted=None):
|
||||||
|
@ -95,6 +104,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
retcols=["observed_user_id", "accepted"],
|
retcols=["observed_user_id", "accepted"],
|
||||||
|
desc="get_presence_list",
|
||||||
)
|
)
|
||||||
|
|
||||||
def del_presence_list(self, observer_localpart, observed_userid):
|
def del_presence_list(self, observer_localpart, observed_userid):
|
||||||
|
@ -102,4 +112,5 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
|
desc="del_presence_list",
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
values={"user_id": user_localpart},
|
values={"user_id": user_localpart},
|
||||||
|
desc="create_profile",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_profile_displayname(self, user_localpart):
|
def get_profile_displayname(self, user_localpart):
|
||||||
|
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcol="displayname",
|
retcol="displayname",
|
||||||
|
desc="get_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||||
|
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"displayname": new_displayname},
|
updatevalues={"displayname": new_displayname},
|
||||||
|
desc="set_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_profile_avatar_url(self, user_localpart):
|
def get_profile_avatar_url(self, user_localpart):
|
||||||
|
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcol="avatar_url",
|
retcol="avatar_url",
|
||||||
|
desc="get_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||||
|
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"avatar_url": new_avatar_url},
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
|
desc="set_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
PushRuleEnableTable.table_name,
|
PushRuleEnableTable.table_name,
|
||||||
{'user_name': user_name},
|
{'user_name': user_name},
|
||||||
PushRuleEnableTable.fields
|
PushRuleEnableTable.fields,
|
||||||
|
desc="get_push_rules_enabled_for_user",
|
||||||
)
|
)
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||||
|
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
yield self._simple_delete_one(
|
yield self._simple_delete_one(
|
||||||
PushRuleTable.table_name,
|
PushRuleTable.table_name,
|
||||||
{'user_name': user_name, 'rule_id': rule_id}
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
|
desc="delete_push_rule",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
yield self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
PushRuleEnableTable.table_name,
|
PushRuleEnableTable.table_name,
|
||||||
{'user_name': user_name, 'rule_id': rule_id},
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
{'enabled': enabled}
|
{'enabled': enabled},
|
||||||
|
desc="set_push_rule_enabled",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
|
||||||
ts=pushkey_ts,
|
ts=pushkey_ts,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
data=data
|
data=data
|
||||||
))
|
),
|
||||||
|
desc="add_pusher",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("create_pusher with failed: %s", e)
|
logger.error("create_pusher with failed: %s", e)
|
||||||
raise StoreError(500, "Problem creating pusher.")
|
raise StoreError(500, "Problem creating pusher.")
|
||||||
|
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
|
||||||
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
||||||
yield self._simple_delete_one(
|
yield self._simple_delete_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
dict(app_id=app_id, pushkey=pushkey)
|
{"app_id": app_id, "pushkey": pushkey},
|
||||||
|
desc="delete_pusher_by_app_id_pushkey",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'last_token': last_token}
|
{'last_token': last_token},
|
||||||
|
desc="update_pusher_last_token",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'last_token': last_token, 'last_success': last_success}
|
{'last_token': last_token, 'last_success': last_success},
|
||||||
|
desc="update_pusher_last_token_and_success",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'failing_since': failing_since}
|
{'failing_since': failing_since},
|
||||||
|
desc="update_pusher_failing_since",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(SQLBaseStore):
|
||||||
|
@ -44,7 +44,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"token": token
|
"token": token
|
||||||
}
|
},
|
||||||
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -87,6 +88,11 @@ class RegistrationStore(SQLBaseStore):
|
||||||
"get_user_by_id", self.cursor_to_dict, query, user_id
|
"get_user_by_id", self.cursor_to_dict, query, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
# TODO(paul): Currently there's no code to invalidate this cache. That
|
||||||
|
# means if/when we ever add internal ways to invalidate access tokens or
|
||||||
|
# change whether a user is a server admin, those will need to invoke
|
||||||
|
# store.get_user_by_token.invalidate(token)
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_token(self, token):
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
|
@ -111,6 +117,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
keyvalues={"name": user.to_string()},
|
keyvalues={"name": user.to_string()},
|
||||||
retcol="admin",
|
retcol="admin",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="is_server_admin",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
|
@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"last_check": self._clock.time_msec(),
|
"last_check": self._clock.time_msec(),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_rejection_reason(self, event_id):
|
def get_rejection_reason(self, event_id):
|
||||||
|
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
},
|
},
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_rejection_reason",
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,11 +15,9 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from sqlite3 import IntegrityError
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore, Table
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
@ -27,8 +25,9 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
OpsLevel = collections.namedtuple("OpsLevel", (
|
OpsLevel = collections.namedtuple(
|
||||||
"ban_level", "kick_level", "redact_level")
|
"OpsLevel",
|
||||||
|
("ban_level", "kick_level", "redact_level",)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
|
||||||
StoreError if the room could not be stored.
|
StoreError if the room could not be stored.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
yield self._simple_insert(RoomsTable.table_name, dict(
|
yield self._simple_insert(
|
||||||
room_id=room_id,
|
RoomsTable.table_name,
|
||||||
creator=room_creator_user_id,
|
{
|
||||||
is_public=is_public
|
"room_id": room_id,
|
||||||
))
|
"creator": room_creator_user_id,
|
||||||
except IntegrityError:
|
"is_public": is_public,
|
||||||
raise StoreError(409, "Room ID in use.")
|
},
|
||||||
|
desc="store_room",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
@ -66,12 +67,11 @@ class RoomStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple containing the room information, or an empty list.
|
A namedtuple containing the room information, or an empty list.
|
||||||
"""
|
"""
|
||||||
query = RoomsTable.select_statement("room_id=?")
|
return self._simple_select_one(
|
||||||
return self._execute(
|
table=RoomsTable.table_name,
|
||||||
"get_room",
|
keyvalues={"room_id": room_id},
|
||||||
lambda txn: RoomsTable.decode_single_result(txn.fetchall()),
|
retcols=RoomsTable.fields,
|
||||||
query,
|
desc="get_room",
|
||||||
room_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -146,7 +146,7 @@ class RoomStore(SQLBaseStore):
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"topic": event.content["topic"],
|
"topic": event.content["topic"],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_room_name_txn(self, txn, event):
|
def _store_room_name_txn(self, txn, event):
|
||||||
|
@ -199,7 +199,7 @@ class RoomStore(SQLBaseStore):
|
||||||
defer.returnValue((name, aliases))
|
defer.returnValue((name, aliases))
|
||||||
|
|
||||||
|
|
||||||
class RoomsTable(Table):
|
class RoomsTable(object):
|
||||||
table_name = "rooms"
|
table_name = "rooms"
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
|
@ -207,5 +207,3 @@ class RoomsTable(Table):
|
||||||
"is_public",
|
"is_public",
|
||||||
"creator"
|
"creator"
|
||||||
]
|
]
|
||||||
|
|
||||||
EntryType = collections.namedtuple("RoomEntry", fields)
|
|
||||||
|
|
|
@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
"room_hosts",
|
"room_hosts",
|
||||||
{"room_id": room_id},
|
{"room_id": room_id},
|
||||||
"host"
|
"host",
|
||||||
|
desc="get_joined_hosts_for_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_members_by_dict(self, where_dict):
|
def _get_members_by_dict(self, where_dict):
|
||||||
|
|
|
@ -160,3 +160,4 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
def _make_group_id(clock):
|
def _make_group_id(clock):
|
||||||
return str(int(clock.time_msec())) + random_string(5)
|
return str(int(clock.time_msec())) + random_string(5)
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ what sort order was used:
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
@ -413,6 +413,7 @@ class StreamStore(SQLBaseStore):
|
||||||
"get_recent_events_for_room", get_recent_events_for_room_txn
|
"get_recent_events_for_room", get_recent_events_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(num_args=0)
|
||||||
def get_room_events_max_id(self):
|
def get_room_events_max_id(self):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_room_events_max_id",
|
"get_room_events_max_id",
|
||||||
|
|
|
@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||||
where_clause = "transaction_id = ? AND origin = ?"
|
result = self._simple_select_one_txn(
|
||||||
query = ReceivedTransactionsTable.select_statement(where_clause)
|
txn,
|
||||||
|
table=ReceivedTransactionsTable.table_name,
|
||||||
|
keyvalues={
|
||||||
|
"transaction_id": transaction_id,
|
||||||
|
"origin": origin,
|
||||||
|
},
|
||||||
|
retcols=ReceivedTransactionsTable.fields,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(query, (transaction_id, origin))
|
if result and result.response_code:
|
||||||
|
return result["response_code"], result["response_json"]
|
||||||
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
|
|
||||||
|
|
||||||
if results and results[0].response_code:
|
|
||||||
return (results[0].response_code, results[0].response_json)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
_string_with_symbols = (
|
||||||
|
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def origin_from_ucid(ucid):
|
def origin_from_ucid(ucid):
|
||||||
return ucid.split("@", 1)[1]
|
return ucid.split("@", 1)[1]
|
||||||
|
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
|
||||||
|
|
||||||
def random_string(length):
|
def random_string(length):
|
||||||
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
||||||
|
|
||||||
|
|
||||||
|
def random_string_with_symbols(length):
|
||||||
|
return ''.join(
|
||||||
|
random.choice(_string_with_symbols) for _ in xrange(length)
|
||||||
|
)
|
||||||
|
|
|
@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
self.mock_txn.fetchone.return_value = ("Old Value",)
|
self.mock_txn.fetchone.return_value = ("Old Value",)
|
||||||
|
|
||||||
ret = yield self.datastore._simple_update_one(
|
ret = yield self.datastore._simple_selectupdate_one(
|
||||||
table="tablename",
|
table="tablename",
|
||||||
keyvalues={"keycol": "TheKey"},
|
keyvalues={"keycol": "TheKey"},
|
||||||
updatevalues={"columname": "New Value"},
|
updatevalues={"columname": "New Value"},
|
||||||
|
|
|
@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_room(self):
|
def test_get_room(self):
|
||||||
self.assertObjectHasAttributes(
|
self.assertDictContainsSubset(
|
||||||
{"room_id": self.room.to_string(),
|
{"room_id": self.room.to_string(),
|
||||||
"creator": self.u_creator.to_string(),
|
"creator": self.u_creator.to_string(),
|
||||||
"is_public": True},
|
"is_public": True},
|
||||||
|
|
Loading…
Reference in a new issue