Merge branch 'develop' into anoa/hs_password_reset_sending_email

This commit is contained in:
Andrew Morgan 2019-06-04 19:07:41 +01:00
commit 9567c60ffa
117 changed files with 2310 additions and 1055 deletions

View file

@ -1,8 +1,16 @@
Synapse 0.99.5.2 (2019-05-30)
=============================
Bugfixes
--------
- Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. ([\#5274](https://github.com/matrix-org/synapse/issues/5274), [\#5278](https://github.com/matrix-org/synapse/issues/5278), [\#5291](https://github.com/matrix-org/synapse/issues/5291))
Synapse 0.99.5.1 (2019-05-22) Synapse 0.99.5.1 (2019-05-22)
============================= =============================
No significant changes. 0.99.5.1 supersedes 0.99.5 due to malformed debian changelog - no functional changes.
Synapse 0.99.5 (2019-05-22) Synapse 0.99.5 (2019-05-22)
=========================== ===========================

1
changelog.d/5216.misc Normal file
View file

@ -0,0 +1 @@
Synapse will now serve the experimental "room complexity" API endpoint.

1
changelog.d/5226.misc Normal file
View file

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5250.misc Normal file
View file

@ -0,0 +1 @@
Simplification to Keyring.wait_for_previous_lookups.

1
changelog.d/5251.bugfix Normal file
View file

@ -0,0 +1 @@
Ensure that server_keys fetched via a notary server are correctly signed.

1
changelog.d/5256.bugfix Normal file
View file

@ -0,0 +1 @@
Show the correct error when logging out and access token is missing.

1
changelog.d/5257.bugfix Normal file
View file

@ -0,0 +1 @@
Fix error code when there is an invalid parameter on /_matrix/client/r0/publicRooms

1
changelog.d/5258.bugfix Normal file
View file

@ -0,0 +1 @@
Fix error when downloading thumbnail with missing width/height parameter.

1
changelog.d/5260.feature Normal file
View file

@ -0,0 +1 @@
Synapse now more efficiently collates room statistics.

1
changelog.d/5268.bugfix Normal file
View file

@ -0,0 +1 @@
Fix schema update for account validity.

1
changelog.d/5274.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we leaked extremities when we soft failed events, leading to performance degradation.

1
changelog.d/5275.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "db txn 'update_presence' from sentinel context" log messages.

1
changelog.d/5276.feature Normal file
View file

@ -0,0 +1 @@
Allow configuring a range for the account validity startup job.

1
changelog.d/5277.bugfix Normal file
View file

@ -0,0 +1 @@
Fix dropped logcontexts during high outbound traffic.

1
changelog.d/5278.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we leaked extremities when we soft failed events, leading to performance degradation.

1
changelog.d/5282.doc Normal file
View file

@ -0,0 +1 @@
Fix docs on resetting the user directory.

1
changelog.d/5283.misc Normal file
View file

@ -0,0 +1 @@
Specify the type of reCAPTCHA key to use.

1
changelog.d/5286.feature Normal file
View file

@ -0,0 +1 @@
CAS login will now hit the r0 API, not the deprecated v1 one.

1
changelog.d/5287.misc Normal file
View file

@ -0,0 +1 @@
Remove spurious debug from MatrixFederationHttpClient.get_json.

1
changelog.d/5288.misc Normal file
View file

@ -0,0 +1 @@
Improve logging for logcontext leaks.

1
changelog.d/5291.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we leaked extremities when we soft failed events, leading to performance degradation.

1
changelog.d/5293.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where it is not possible to get events in the federation format with the request `GET /_matrix/client/r0/rooms/{roomId}/messages`.

1
changelog.d/5294.bugfix Normal file
View file

@ -0,0 +1 @@
Fix performance problems with the rooms stats background update.

1
changelog.d/5296.misc Normal file
View file

@ -0,0 +1 @@
Refactor keyring.VerifyKeyRequest to use attr.s.

1
changelog.d/5299.misc Normal file
View file

@ -0,0 +1 @@
Rewrite get_server_verify_keys, again.

1
changelog.d/5300.bugfix Normal file
View file

@ -0,0 +1 @@
Fix noisy 'no key for server' logs.

1
changelog.d/5303.misc Normal file
View file

@ -0,0 +1 @@
Clarify that the admin change password API logs the user out.

1
changelog.d/5307.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where a notary server would sometimes forget old keys.

1
changelog.d/5309.bugfix Normal file
View file

@ -0,0 +1 @@
Prevent users from setting huge displaynames and avatar URLs.

1
changelog.d/5321.bugfix Normal file
View file

@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.

1
changelog.d/5324.feature Normal file
View file

@ -0,0 +1 @@
Synapse now more efficiently collates room statistics.

1
changelog.d/5328.misc Normal file
View file

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5332.misc Normal file
View file

@ -0,0 +1 @@
Improve docstrings on MatrixFederationClient.

1
changelog.d/5333.bugfix Normal file
View file

@ -0,0 +1 @@
Fix various problems which made the signing-key notary server time out for some requests.

1
changelog.d/5334.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug which would make certain operations (such as room joins) block for 20 minutes while attemoting to fetch verification keys.

1
changelog.d/5335.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where we could rapidly mark a server as unreachable even though it was only down for a few minutes.

1
changelog.d/5341.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where account validity renewal emails could only be sent when email notifs were enabled.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (0.99.5.2) stable; urgency=medium
* New synapse release 0.99.5.2.
-- Synapse Packaging team <packages@matrix.org> Thu, 30 May 2019 16:28:07 +0100
matrix-synapse-py3 (0.99.5.1) stable; urgency=medium matrix-synapse-py3 (0.99.5.1) stable; urgency=medium
* New synapse release 0.99.5.1. * New synapse release 0.99.5.1.

View file

@ -7,6 +7,7 @@ Requires a public/private key pair from:
https://developers.google.com/recaptcha/ https://developers.google.com/recaptcha/
Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option
Setting ReCaptcha Keys Setting ReCaptcha Keys
---------------------- ----------------------

View file

@ -69,7 +69,7 @@ An empty body may be passed for backwards compatibility.
Reset password Reset password
============== ==============
Changes the password of another user. Changes the password of another user. This will automatically log the user out of all their devices.
The api is:: The api is::

View file

@ -763,7 +763,9 @@ uploads_path: "DATADIR/uploads"
# This means that, if a validity period is set, and Synapse is restarted (it will # This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time # then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users' # after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. # expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10% of the validity period.
# #
#account_validity: #account_validity:
# enabled: True # enabled: True
@ -1103,9 +1105,9 @@ password_config:
# #
# 'search_all_users' defines whether to search all users visible to your HS # 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible # when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run # in public rooms. Defaults to false. If you set it True, you'll have to
# UPDATE user_directory_stream_pos SET stream_id = NULL; # rebuild the user_directory search indexes, see
# on your database to tell it to rebuild the user_directory search indexes. # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
# #
#user_directory: #user_directory:
# enabled: true # enabled: true

View file

@ -7,11 +7,7 @@ who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the DB corruption) get stale or out of sync. If this happens, for now the
quickest solution to fix it is: solution to fix it is to execute the SQL here
https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/delta/53/user_dir_populate.sql
``` and then restart synapse. This should then start a background task to
UPDATE user_directory_stream_pos SET stream_id = NULL;
```
and restart the synapse, which should then start a background task to
flush the current tables and regenerate the directory. flush the current tables and regenerate the directory.

View file

@ -20,9 +20,7 @@ class CallVisitor(ast.NodeVisitor):
else: else:
return return
if name == "client_path_patterns": if name == "client_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s) PATTERNS_V2.append(node.args[0].s)

View file

@ -27,4 +27,4 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "0.99.5.1" __version__ = "0.99.5.2"

View file

@ -344,15 +344,21 @@ class _LimitedHostnameResolver(object):
def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
addressTypes=None, transportSemantics='TCP'): addressTypes=None, transportSemantics='TCP'):
# Note this is happening deep within the reactor, so we don't need to
# worry about log contexts.
# We need this function to return `resolutionReceiver` so we do all the # We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function. # actual logic involving deferreds in a separate function.
self._resolve(
resolutionReceiver, hostName, portNumber, # even though this is happening within the depths of twisted, we need to drop
addressTypes, transportSemantics, # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
) # the logcontext if it returns an incomplete deferred; (b) _resolve will
# call the resolutionReceiver *with* a logcontext, which it won't be expecting.
with PreserveLoggingContext():
self._resolve(
resolutionReceiver,
hostName,
portNumber,
addressTypes,
transportSemantics,
)
return resolutionReceiver return resolutionReceiver

View file

@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy") logger = logging.getLogger("synapse.app.frontend_proxy")
class PresenceStatusStubServlet(ClientV1RestServlet): class PresenceStatusStubServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__(hs) super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri self.main_uri = hs.config.worker_main_http_uri
@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet): class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,10 +31,48 @@ logger = logging.getLogger(__name__)
class EmailConfig(Config): class EmailConfig(Config):
def read_config(self, config): def read_config(self, config):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
self.email_enable_notifs = False self.email_enable_notifs = False
email_config = config.get("email", {}) email_config = config.get("email", {})
self.email_smtp_host = email_config.get("smtp_host", None)
self.email_smtp_port = email_config.get("smtp_port", None)
self.email_smtp_user = email_config.get("smtp_user", None)
self.email_smtp_pass = email_config.get("smtp_pass", None)
self.require_transport_security = email_config.get(
"require_transport_security", False
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
self.email_app_name = "Matrix"
self.email_notif_from = email_config.get("notif_from", None)
if self.email_notif_from is not None:
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
if parsed[1] == '':
raise RuntimeError("Invalid notif_from address")
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
# we don't yet know what auxilliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
template_dir = pkg_resources.resource_filename(
'synapse', 'res/templates'
)
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False) self.email_enable_notifs = email_config.get("enable_notifs", False)
account_validity_renewal_enabled = config.get(
"account_validity", {},
).get("renew_at")
self.email_enable_password_reset_from_is = email_config.get( self.email_enable_password_reset_from_is = email_config.get(
"enable_password_reset_from_is", False, "enable_password_reset_from_is", False,
@ -42,7 +82,11 @@ class EmailConfig(Config):
"validation_token_lifetime", 15 * 60, "validation_token_lifetime", 15 * 60,
) )
if email_config != {}: if (
self.email_enable_notifs
or account_validity_renewal_enabled
or self.email_enable_password_reset_from_is
):
# make sure we can import the required deps # make sure we can import the required deps
import jinja2 import jinja2
import bleach import bleach
@ -50,39 +94,6 @@ class EmailConfig(Config):
jinja2 jinja2
bleach bleach
self.email_smtp_host = email_config["smtp_host"]
self.email_smtp_port = email_config["smtp_port"]
self.email_notif_from = email_config["notif_from"]
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
# we don't yet know what auxilliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
template_dir = pkg_resources.resource_filename(
'synapse', 'res/templates'
)
self.email_template_dir = os.path.abspath(template_dir)
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
self.email_smtp_user = email_config.get(
"smtp_user", None
)
self.email_smtp_pass = email_config.get(
"smtp_pass", None
)
self.require_transport_security = email_config.get(
"require_transport_security", False
)
self.email_app_name = email_config.get("app_name", "Matrix")
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
if parsed[1] == '':
raise RuntimeError("Invalid notif_from address")
if not self.email_enable_password_reset_from_is: if not self.email_enable_password_reset_from_is:
required = [ required = [
"smtp_host", "smtp_host",
@ -150,12 +161,6 @@ class EmailConfig(Config):
self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"] self.email_notif_template_text = email_config["notif_template_text"]
self.email_expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html",
)
self.email_expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt",
)
for f in self.email_notif_template_text, self.email_notif_template_html: for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f) p = os.path.join(self.email_template_dir, f)
@ -165,6 +170,26 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get( self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True "notif_for_new_users", True
) )
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
else:
self.email_enable_notifs = False
# Not much point setting defaults for the rest: it would be an
# error for them to be used.
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html",
)
self.email_expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt",
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p, ))
def _get_template_content(self, template_dir, path): def _get_template_content(self, template_dir, path):
fullpath = os.path.join(template_dir, path) fullpath = os.path.join(template_dir, path)

View file

@ -39,6 +39,8 @@ class AccountValidityConfig(Config):
else: else:
self.renew_email_subject = "Renew your %(app)s account" self.renew_email_subject = "Renew your %(app)s account"
self.startup_job_max_delta = self.period * 10. / 100.
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'") raise ConfigError("Can't send renewal emails without 'public_baseurl'")
@ -129,7 +131,9 @@ class RegistrationConfig(Config):
# This means that, if a validity period is set, and Synapse is restarted (it will # This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time # then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users' # after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. # expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10%% of the validity period.
# #
#account_validity: #account_validity:
# enabled: True # enabled: True

View file

@ -43,9 +43,9 @@ class UserDirectoryConfig(Config):
# #
# 'search_all_users' defines whether to search all users visible to your HS # 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible # when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run # in public rooms. Defaults to false. If you set it True, you'll have to
# UPDATE user_directory_stream_pos SET stream_id = NULL; # rebuild the user_directory search indexes, see
# on your database to tell it to rebuild the user_directory search indexes. # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
# #
#user_directory: #user_directory:
# enabled: true # enabled: true

View file

@ -15,11 +15,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple from collections import defaultdict
import six
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
import attr
from signedjson.key import ( from signedjson.key import (
decode_verify_key_bytes, decode_verify_key_bytes,
encode_verify_key_base64, encode_verify_key_base64,
@ -44,6 +46,7 @@ from synapse.api.errors import (
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
@ -56,22 +59,36 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple( @attr.s(slots=True, cmp=False)
"VerifyRequest", ("server_name", "key_ids", "json_object", "deferred") class VerifyKeyRequest(object):
) """
""" A request for a verify key to verify a JSON object.
A request for a verify key to verify a JSON object.
Attributes: Attributes:
server_name(str): The name of the server to verify against. server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object key_ids(set[str]): The set of key_ids to that could be used to verify the
json_object(dict): The JSON object to verify. JSON object
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when json_object(dict): The JSON object to verify.
a verify key has been fetched. The deferreds' callbacks are run with no
logcontext. minimum_valid_until_ts (int): time at which we require the signing key to
""" be valid. (0 implies we don't care)
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
logcontext.
If we are unable to find a key which satisfies the request, the deferred
errbacks with an M_UNAUTHORIZED SynapseError.
"""
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
minimum_valid_until_ts = attr.ib()
deferred = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError): class KeyLookupError(ValueError):
@ -79,14 +96,16 @@ class KeyLookupError(ValueError):
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._key_fetchers = ( if key_fetchers is None:
StoreKeyFetcher(hs), key_fetchers = (
PerspectivesKeyFetcher(hs), StoreKeyFetcher(hs),
ServerKeyFetcher(hs), PerspectivesKeyFetcher(hs),
) ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with # map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download # an ongoing key download; the Deferred completes once the download
@ -95,9 +114,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds. # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server
Args:
server_name (str): name of the server which must have signed this object
json_object (dict): object to be checked
validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)
Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time
return logcontext.make_deferred_yieldable( return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0] self.verify_json_objects_for_server((req,))[0]
) )
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
@ -105,10 +140,12 @@ class Keyring(object):
necessary. necessary.
Args: Args:
server_and_json (list): List of pairs of (server_name, json_object) server_and_json (iterable[Tuple[str, dict, int]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.
Returns: Returns:
List<Deferred>: for each input pair, a deferred indicating success List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel server_name. The deferreds run their callbacks in the sentinel
logcontext. logcontext.
@ -117,12 +154,12 @@ class Keyring(object):
verify_requests = [] verify_requests = []
handle = preserve_fn(_handle_key_deferred) handle = preserve_fn(_handle_key_deferred)
def process(server_name, json_object): def process(server_name, json_object, validity_time):
"""Process an entry in the request list """Process an entry in the request list
Given a (server_name, json_object) pair from the request list, Given a (server_name, json_object, validity_time) triplet from the request
adds a key request to verify_requests, and returns a deferred which will list, adds a key request to verify_requests, and returns a deferred which
complete or fail (in the sentinel context) when verification completes. will complete or fail (in the sentinel context) when verification completes.
""" """
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
@ -133,11 +170,16 @@ class Keyring(object):
) )
) )
logger.debug("Verifying for %s with key_ids %s", server_name, key_ids) logger.debug(
"Verifying for %s with key_ids %s, min_validity %i",
server_name,
key_ids,
validity_time,
)
# add the key request to the queue, but don't start it off yet. # add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest( verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred() server_name, key_ids, json_object, validity_time
) )
verify_requests.append(verify_request) verify_requests.append(verify_request)
@ -149,8 +191,8 @@ class Keyring(object):
return handle(verify_request) return handle(verify_request)
results = [ results = [
process(server_name, json_object) process(server_name, json_object, validity_time)
for server_name, json_object in server_and_json for server_name, json_object, validity_time in server_and_json
] ]
if verify_requests: if verify_requests:
@ -180,9 +222,7 @@ class Keyring(object):
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
yield self.wait_for_previous_lookups( yield self.wait_for_previous_lookups(server_to_deferred)
[rq.server_name for rq in verify_requests], server_to_deferred
)
# Actually start fetching keys. # Actually start fetching keys.
self._get_server_verify_keys(verify_requests) self._get_server_verify_keys(verify_requests)
@ -215,12 +255,11 @@ class Keyring(object):
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred): def wait_for_previous_lookups(self, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_names (list): list of server_names we want to lookup server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server. resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their The Deferreds should be regular twisted ones which call their
callbacks with no logcontext. callbacks with no logcontext.
@ -233,7 +272,7 @@ class Keyring(object):
while True: while True:
wait_on = [ wait_on = [
(server_name, self.key_downloads[server_name]) (server_name, self.key_downloads[server_name])
for server_name in server_names for server_name in server_to_deferred.keys()
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if not wait_on: if not wait_on:
@ -272,83 +311,112 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
remaining_requests = set(
(rq for rq in verify_requests if not rq.deferred.called)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
for f in self._key_fetchers: for f in self._key_fetchers:
results = yield f.get_keys(missing_keys.items()) if not remaining_requests:
return
# We now need to figure out which verify requests we have keys yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
server_name = verify_request.server_name
# see if any of the keys we got this time are sufficient to
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if fetch_key_result:
with PreserveLoggingContext():
verify_request.deferred.callback(
(
server_name,
key_id,
fetch_key_result.verify_key,
)
)
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break
# look for any requests which weren't satisfied
with PreserveLoggingContext(): with PreserveLoggingContext():
for verify_request in requests_missing_keys: for verify_request in remaining_requests:
verify_request.deferred.errback( verify_request.deferred.errback(
SynapseError( SynapseError(
401, 401,
"No key for %s with id %s" "No key for %s with ids in %s (min_validity %i)"
% (verify_request.server_name, verify_request.key_ids), % (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
) )
def on_err(err): def on_err(err):
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext(): with PreserveLoggingContext():
for verify_request in verify_requests: for verify_request in remaining_requests:
if not verify_request.deferred.called: if not verify_request.deferred.called:
verify_request.deferred.errback(err) verify_request.deferred.errback(err)
run_in_background(do_iterations).addErrback(on_err) run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
# dict[str, dict[str, int]]: keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
keys_for_server = missing_keys[verify_request.server_name]
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts
)
results = yield fetcher.get_keys(missing_keys)
completed = list()
for verify_request in remaining_requests:
server_name = verify_request.server_name
# see if any of the keys we got this time are sufficient to
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
continue
if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, fetch_key_result.verify_key)
)
completed.append(verify_request)
break
remaining_requests.difference_update(completed)
class KeyFetcher(object): class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@ -364,13 +432,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, key_ids in server_name_and_key_ids for server_name, keys_for_server in keys_to_fetch.items()
for key_id in key_ids for key_id in keys_for_server.keys()
) )
res = yield self.store.get_server_verify_keys(keys_to_fetch) res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {} keys = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
@ -385,7 +455,7 @@ class BaseV2KeyFetcher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response( def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[] self, from_server, response_json, time_added_ms
): ):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
@ -394,8 +464,7 @@ class BaseV2KeyFetcher(object):
POST /_matrix/key/v2/query. POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for server is valid, and that there is at least one such signature.
some reason.)
Stores the json in server_keys_json so that it can be used for future responses Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query. to /_matrix/key/v2/query.
@ -409,10 +478,6 @@ class BaseV2KeyFetcher(object):
time_added_ms (int): the timestamp to record in server_keys_json time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
""" """
@ -430,16 +495,25 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=ts_valid_until_ms verify_key=verify_key, valid_until_ts=ts_valid_until_ms
) )
# TODO: improve this signature checking
server_name = response_json["server_name"] server_name = response_json["server_name"]
verified = False
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys: # each of the keys used for the signature must be present in the response
# json.
key = verify_keys.get(key_id)
if not key:
raise KeyLookupError( raise KeyLookupError(
"Key response must include verification keys for all signatures" "Key response is signed by key id %s:%s but that key is not "
"present in the response" % (server_name, key_id)
) )
verify_signed_json( verify_signed_json(response_json, server_name, key.verify_key)
response_json, server_name, verify_keys[key_id].verify_key verified = True
if not verified:
raise KeyLookupError(
"Key response for %s is not signed by the origin server"
% (server_name,)
) )
for key_id, key_data in response_json["old_verify_keys"].items(): for key_id, key_data in response_json["old_verify_keys"].items():
@ -459,11 +533,6 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable( yield logcontext.make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
@ -476,7 +545,7 @@ class BaseV2KeyFetcher(object):
ts_expires_ms=ts_valid_until_ms, ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
for key_id in updated_key_ids for key_id in verify_keys
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -495,14 +564,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys keys_to_fetch, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except KeyLookupError as e: except KeyLookupError as e:
@ -536,23 +605,32 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, keys_to_fetch, perspective_name, perspective_keys
): ):
""" """
Args: Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
perspective_name (str): name of the notary server to query for the keys perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server notary server
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
from server_name -> key_id -> FetchKeyResult from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
the server
""" """
# TODO(mark): Set the minimum_valid_until_ts to that needed by logger.info(
# the events being validated or the current time if validating "Requesting keys %s from notary server %s",
# an incoming request. keys_to_fetch.items(),
perspective_name,
)
try: try:
query_response = yield self.client.post_json( query_response = yield self.client.post_json(
destination=perspective_name, destination=perspective_name,
@ -560,12 +638,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={ data={
u"server_keys": { u"server_keys": {
server_name: { server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
} }
for server_name, key_ids in server_names_and_key_ids for server_name, server_keys in keys_to_fetch.items()
} }
}, },
long_retries=True,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e) raise_from(KeyLookupError("Failed to connect to remote server"), e)
@ -578,40 +656,31 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
for response in query_response["server_keys"]: for response in query_response["server_keys"]:
if ( # do this first, so that we can give useful errors thereafter
u"signatures" not in response server_name = response.get("server_name")
or perspective_name not in response[u"signatures"] if not isinstance(server_name, six.string_types):
):
raise KeyLookupError( raise KeyLookupError(
"Key response not signed by perspective server" "Malformed response from key notary server %s: invalid server_name"
" %r" % (perspective_name,) % (perspective_name,)
) )
verified = False try:
for key_id in response[u"signatures"][perspective_name]: processed_response = yield self._process_perspectives_response(
if key_id in perspective_keys:
verify_signed_json(
response, perspective_name, perspective_keys[key_id]
)
verified = True
if not verified:
logging.info(
"Response from perspective server %r not signed with a"
" known key, signed with: %r, known keys: %r",
perspective_name, perspective_name,
list(response[u"signatures"][perspective_name]), perspective_keys,
list(perspective_keys), response,
time_added_ms=time_now_ms,
) )
raise KeyLookupError( except KeyLookupError as e:
"Response not signed with a known key for perspective" logger.warning(
" server %r" % (perspective_name,) "Error processing response from key notary server %s for origin "
"server %s: %s",
perspective_name,
server_name,
e,
) )
# we continue to process the rest of the response
processed_response = yield self.process_v2_response( continue
perspective_name, response, time_added_ms=time_now_ms
)
server_name = response["server_name"]
added_keys.extend( added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items() (server_name, key_id, key) for key_id, key in processed_response.items()
@ -624,6 +693,53 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(keys) defer.returnValue(keys)
def _process_perspectives_response(
self, perspective_name, perspective_keys, response, time_added_ms
):
"""Parse a 'Server Keys' structure from the result of a /key/query request
Checks that the entry is correctly signed by the perspectives server, and then
passes over to process_v2_response
Args:
perspective_name (str): the name of the notary server that produced this
result
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
response (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
):
raise KeyLookupError("Response not signed by the notary server")
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(response, perspective_name, perspective_keys[key_id])
verified = True
if not verified:
raise KeyLookupError(
"Response not signed with a known key: signed with: %r, known keys: %r"
% (
list(response[u"signatures"][perspective_name].keys()),
list(perspective_keys.keys()),
)
)
return self.process_v2_response(
perspective_name, response, time_added_ms=time_added_ms
)
class ServerKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers""" """KeyFetcher impl which fetches keys from the origin servers"""
@ -633,34 +749,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
@defer.inlineCallbacks def get_keys(self, keys_to_fetch):
def get_keys(self, server_name_and_key_ids): """
"""see KeyFetcher.get_keys""" Args:
results = yield logcontext.make_deferred_yieldable( keys_to_fetch (dict[str, iterable[str]]):
defer.gatherResults( the keys to be fetched. server_name -> key_ids
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {} Returns:
for result in results: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
merged.update(result) map from server_name -> key_id -> FetchKeyResult
"""
defer.returnValue( results = {}
{server_name: keys for server_name, keys in merged.items() if keys}
@defer.inlineCallbacks
def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item
try:
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning(
"Error looking up keys %s from %s: %s", key_ids, server_name, e
)
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
lambda _: results
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
server_name (str):
key_ids (iterable[str]):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys: if requested_key_id in keys:
continue continue
@ -671,18 +807,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/" path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id), + urllib.parse.quote(requested_key_id),
ignore_backoff=True, ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e) raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e: except HttpResponseException as e:
raise_from(KeyLookupError("Remote server returned an error"), e) raise_from(KeyLookupError("Remote server returned an error"), e)
if (
u"signatures" not in response
or server_name not in response[u"signatures"]
):
raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name: if response["server_name"] != server_name:
raise KeyLookupError( raise KeyLookupError(
"Expected a response for server %r not %r" "Expected a response for server %r not %r"
@ -691,7 +834,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
requested_ids=[requested_key_id],
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )
@ -702,7 +844,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
) )
keys.update(response_keys) keys.update(response_keys)
defer.returnValue({server_name: keys}) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -719,31 +861,8 @@ def _handle_key_deferred(verify_request):
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification
""" """
server_name = verify_request.server_name server_name = verify_request.server_name
try: with PreserveLoggingContext():
with PreserveLoggingContext(): _, key_id, verify_key = yield verify_request.deferred
_, key_id, verify_key = yield verify_request.deferred
except KeyLookupError as e:
logger.warn(
"Failed to download keys for %s: %s %s",
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object json_object = verify_request.json_object

View file

@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json) (p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender for p in pdus_to_check_sender
]) ])
@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json) (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id for p in pdus_to_check_event_id
]) ])

View file

@ -23,7 +23,11 @@ from twisted.internet import defer
import synapse import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -90,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object): class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -98,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request, content): def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = { json_request = {
"method": request.method.decode('ascii'), "method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'), "uri": request.uri.decode('ascii'),
@ -134,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin
@ -1304,6 +1310,30 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
defer.returnValue((200, new_content)) defer.returnValue((200, new_content))
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
resource-intensive) a public room this server knows about is.
"""
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
@defer.inlineCallbacks
def on_GET(self, origin, content, query, room_id):
store = self.handler.hs.get_datastore()
is_public = yield store.is_room_world_readable_or_publicly_joinable(
room_id
)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
complexity = yield store.get_room_complexity(room_id)
defer.returnValue((200, complexity))
FEDERATION_SERVLET_CLASSES = ( FEDERATION_SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationEventServlet, FederationEventServlet,
@ -1327,6 +1357,7 @@ FEDERATION_SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
FederationVersionServlet, FederationVersionServlet,
RoomComplexityServlet,
) )
OPENID_SERVLET_CLASSES = ( OPENID_SERVLET_CLASSES = (

View file

@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give # TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while. # us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec(): now = self.clock.time_msec()
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired") raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(server_name, attestation) yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id): def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default """Create an attestation for the group_id and user_id with default

View file

@ -182,17 +182,27 @@ class PresenceHandler(object):
# Start a LoopingCall in 30s that fires every 5s. # Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to # The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline. # reconnect before we treat them as offline.
def run_timeout_handler():
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
self.clock.call_later( self.clock.call_later(
30, 30,
self.clock.looping_call, self.clock.looping_call,
self._handle_timeouts, run_timeout_handler,
5000, 5000,
) )
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
self.clock.call_later( self.clock.call_later(
60, 60,
self.clock.looping_call, self.clock.looping_call,
self._persist_unpersisted_changes, run_persister,
60 * 1000, 60 * 1000,
) )
@ -229,6 +239,7 @@ class PresenceHandler(object):
) )
if self.unpersisted_users_changes: if self.unpersisted_users_changes:
yield self.store.update_presence([ yield self.store.update_presence([
self.user_to_current_state[user_id] self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes for user_id in self.unpersisted_users_changes
@ -240,30 +251,18 @@ class PresenceHandler(object):
"""We periodically persist the unpersisted changes, as otherwise they """We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times. may stack up and slow down shutdown times.
""" """
logger.info(
"Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set() self.unpersisted_users_changes = set()
if unpersisted: if unpersisted:
logger.info(
"Persisting %d upersisted presence updates", len(unpersisted)
)
yield self.store.update_presence([ yield self.store.update_presence([
self.user_to_current_state[user_id] self.user_to_current_state[user_id]
for user_id in unpersisted for user_id in unpersisted
]) ])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
def _update_states_and_catch_exception(self, new_states):
try:
res = yield self._update_states(new_states)
defer.returnValue(res)
except Exception:
logger.exception("Error updating presence")
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_states(self, new_states): def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes """Updates presence of users. Sets the appropriate timeouts. Pokes
@ -338,45 +337,41 @@ class PresenceHandler(object):
logger.info("Handling presence timeouts") logger.info("Handling presence timeouts")
now = self.clock.time_msec() now = self.clock.time_msec()
try: # Fetch the list of users that *may* have timed out. Things may have
with Measure(self.clock, "presence_handle_timeouts"): # changed since the timeout was set, so we won't necessarily have to
# Fetch the list of users that *may* have timed out. Things may have # take any action.
# changed since the timeout was set, so we won't necessarily have to users_to_check = set(self.wheel_timer.fetch(now))
# take any action.
users_to_check = set(self.wheel_timer.fetch(now))
# Check whether the lists of syncing processes from an external # Check whether the lists of syncing processes from an external
# process have expired. # process have expired.
expired_process_ids = [ expired_process_ids = [
process_id for process_id, last_update process_id for process_id, last_update
in self.external_process_last_updated_ms.items() in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY if now - last_update > EXTERNAL_PROCESS_EXPIRY
] ]
for process_id in expired_process_ids: for process_id in expired_process_ids:
users_to_check.update( users_to_check.update(
self.external_process_last_updated_ms.pop(process_id, ()) self.external_process_last_updated_ms.pop(process_id, ())
) )
self.external_process_last_update.pop(process_id) self.external_process_last_update.pop(process_id)
states = [ states = [
self.user_to_current_state.get( self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id) user_id, UserPresenceState.default(user_id)
) )
for user_id in users_to_check for user_id in users_to_check
] ]
timers_fired_counter.inc(len(states)) timers_fired_counter.inc(len(states))
changes = handle_timeouts( changes = handle_timeouts(
states, states,
is_mine_fn=self.is_mine_id, is_mine_fn=self.is_mine_id,
syncing_user_ids=self.get_currently_syncing_users(), syncing_user_ids=self.get_currently_syncing_users(),
now=now, now=now,
) )
run_in_background(self._update_states_and_catch_exception, changes) return self._update_states(changes)
except Exception:
logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks @defer.inlineCallbacks
def bump_presence_active_time(self, user): def bump_presence_active_time(self, user):

View file

@ -31,6 +31,9 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 100
MAX_AVATAR_URL_LEN = 1000
class BaseProfileHandler(BaseHandler): class BaseProfileHandler(BaseHandler):
"""Handles fetching and updating user profile information. """Handles fetching and updating user profile information.
@ -162,6 +165,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
)
if new_displayname == '': if new_displayname == '':
new_displayname = None new_displayname = None
@ -217,6 +225,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
)
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
target_user.localpart, new_avatar_url target_user.localpart, new_avatar_url
) )

View file

@ -531,6 +531,8 @@ class RegistrationHandler(BaseHandler):
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
NB this is only used in tests. TODO: move it to the test package!
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")

View file

@ -285,7 +285,24 @@ class MatrixFederationHttpClient(object):
request (MatrixFederationRequest): details of request to be sent request (MatrixFederationRequest): details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server). 60s by default. (including connecting to the server), *for each attempt*.
60s by default.
long_retries (bool): whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
The long retry algorithm makes 11 attempts, with intervals
[4s, 16s, 60s, 60s, ...]
Both algorithms add -20%/+40% jitter to the retry intervals.
Note that the above intervals are *in addition* to the time spent
waiting for the request to complete (up to `timeout` ms).
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
@ -566,10 +583,14 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to json_data_callback (callable): A callable returning the dict to
use as the request body. use as the request body.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as backoff_on_404 (bool): True if we should count a 404 response as
@ -627,15 +648,22 @@ class MatrixFederationHttpClient(object):
Args: Args:
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
data (dict): A dict containing the data that will be used as data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. try the request anyway.
args (dict): query params args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@ -686,14 +714,19 @@ class MatrixFederationHttpClient(object):
Args: Args:
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
args (dict|None): A dictionary used to create query strings, defaults to args (dict|None): A dictionary used to create query strings, defaults to
None. None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will timeout (int|None): number of milliseconds to wait for the response headers
be retried. (including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3. the request. Workaround for #3622 in Synapse <= v0.99.3.
@ -711,10 +744,6 @@ class MatrixFederationHttpClient(object):
RequestSendFailed: If there were problems connecting to the RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc. remote, due to e.g. DNS failures, connection timeouts etc.
""" """
logger.debug("get_json args: %s", args)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
request = MatrixFederationRequest( request = MatrixFederationRequest(
method="GET", method="GET",
destination=destination, destination=destination,
@ -746,12 +775,18 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. try the request anyway.
args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.

View file

@ -55,7 +55,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
return int(args[name][0]) return int(args[name][0])
except Exception: except Exception:
message = "Query parameter %r must be an integer" % (name,) message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
else: else:
if required: if required:
message = "Missing integer query parameter %r" % (name,) message = "Missing integer query parameter %r" % (name,)

View file

@ -822,10 +822,16 @@ class AdminRestResource(JsonResource):
def __init__(self, hs): def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False) JsonResource.__init__(self, hs, canonical_json=False)
register_servlets(hs, self)
register_servlets_for_client_rest_resource(hs, self)
SendServerNoticeServlet(hs).register(self) def register_servlets(hs, http_server):
VersionServlet(hs).register(self) """
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server): def register_servlets_for_client_rest_resource(hs, http_server):

View file

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import logging
import re
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
# This subclass was presumably created to allow the auth for the v1
# protocol version to be different, however this behaviour was removed.
# it may no longer be necessary
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)

View file

@ -19,11 +19,10 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs) super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientDirectoryListServer(ClientV1RestServlet): class ClientDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$") PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet): class ClientAppserviceDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs) super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id): def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View file

@ -19,21 +19,22 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet): class EventStreamRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events$") PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs): def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs) super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()

View file

@ -15,19 +15,19 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs) super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View file

@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
} }
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso" SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt" JWT_TYPE = "m.login.jwt"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__()
self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet): class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__() super(CasRedirectServlet, self).__init__()
@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet):
b"redirectUrl": args[b"redirectUrl"][0] b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii') }).encode('ascii')
hs_redirect_url = (self.cas_service_url + hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/api/v1/login/cas/ticket") b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({ service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii') }).encode('ascii')
@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request) finish_request(request)
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs) super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True) client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
"ticket": parse_string(request, "ticket", required=True), "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url "service": self.cas_service_url
} }
try: try:
body = yield http_client.get_raw(uri, args) body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data

View file

@ -17,19 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet): class LogoutRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout$") PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__()
self._auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
@ -38,32 +37,25 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
try: requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
except AuthError: if requester.device_id is None:
# this implies the access token has already been deleted. # the acccess token wasn't associated with a device.
defer.returnValue((401, { # Just delete the access token
"errcode": "M_UNKNOWN_TOKEN", access_token = self.auth.get_access_token_from_request(request)
"error": "Access Token unknown or expired" yield self._auth_handler.delete_access_token(access_token)
}))
else: else:
if requester.device_id is None: yield self._device_handler.delete_device(
# the acccess token wasn't associated with a device. requester.user.to_string(), requester.device_id)
# Just delete the access token
access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
requester.user.to_string(), requester.device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet): class LogoutAllRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout/all$") PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()

View file

@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs) super(PresenceStatusRestServlet, self).__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View file

@ -16,18 +16,19 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(RestServlet):
class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View file

@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request, parse_string from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(RestServlet):
class PushRuleRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs): def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs) super(PushRuleRestServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None self._is_worker = hs.config.worker_app is not None

View file

@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.client.v2_alpha._base import client_patterns
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PushersRestServlet(ClientV1RestServlet): class PushersRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers$") PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersRestServlet, self).__init__(hs) super(PushersRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {} return 200, {}
class PushersSetRestServlet(ClientV1RestServlet): class PushersSetRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers/set$") PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs) super(PushersSetRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
""" """
To allow pusher to be delete by clicking a link (ie. GET request) To allow pusher to be delete by clicking a link (ie. GET request)
""" """
PATTERNS = client_path_patterns("/pushers/remove$") PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):

View file

@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class TransactionRestServlet(RestServlet):
def __init__(self, hs):
super(TransactionRestServlet, self).__init__()
self.txns = HttpTransactionCache(hs)
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs): def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs) super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/rooms(?:/.*)?$"), client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
# define CORS for /createRoom[/txnid] # define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/createRoom(?:/.*)?$"), client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_GET) self.on_GET)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_PUT) self.on_PUT)
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key) self.on_GET_no_state_key)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key) self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(self, request, room_id, event_type):
@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs) super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet): class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_path_patterns("/publicRooms$") PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
super(PublicRoomListRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs) super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join? # deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format # except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet): class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs) super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs) super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -475,6 +494,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
if event_filter.filter_json.get("event_format", "client") == "federation":
as_client_event = False
else: else:
event_filter = None event_filter = None
msgs = yield self.pagination_handler.get_messages( msgs = yield self.pagination_handler.get_messages(
@ -489,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs) super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -509,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs) super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -528,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class RoomEventServlet(ClientV1RestServlet): class RoomEventServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs) super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -552,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet):
defer.returnValue((404, "Event not found.")) defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet): class RoomEventContextServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContextServlet, self).__init__(hs) super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -607,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs) super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -637,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs) super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
@ -720,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
) )
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -755,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
) )
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs) super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
@ -796,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet): class SearchRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns("/search$", v1=True)
"/search$"
)
def __init__(self, hs): def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs) super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -821,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class JoinedRoomsRestServlet(ClientV1RestServlet): class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_path_patterns("/joined_rooms$") PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomsRestServlet, self).__init__(hs) super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -851,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
""" """
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_path_patterns(regex_string + "$"), client_patterns(regex_string + "$", v1=True),
servlet.on_POST servlet.on_POST
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT servlet.on_PUT
) )
if with_get: if with_get:
http_server.register_paths( http_server.register_paths(
"GET", "GET",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET servlet.on_GET
) )

View file

@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
class VoipRestServlet(ClientV1RestServlet): class VoipRestServlet(RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$") PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
super(VoipRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View file

@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,), def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
unstable=True):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,),
if unstable: if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable" unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex)) patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases: for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))

View file

@ -32,13 +32,13 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet): class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$") PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__() super(EmailPasswordRequestTokenRestServlet, self).__init__()
@ -174,7 +174,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet): class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__() super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@ -212,7 +212,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$") PATTERNS = client_patterns("/account/password$")
def __init__(self, hs): def __init__(self, hs):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
@ -285,7 +285,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__() super(DeactivateAccountRestServlet, self).__init__()
@ -333,7 +333,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -368,7 +368,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -405,7 +405,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$") PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
@ -469,7 +469,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$") PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__() super(ThreepidDeleteRestServlet, self).__init__()
@ -506,7 +506,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$") PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs): def __init__(self, hs):
super(WhoamiRestServlet, self).__init__() super(WhoamiRestServlet, self).__init__()

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)" "/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"

View file

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet): class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/renew$") PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs): def __init__(self, hs):
@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/send_mail$") PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet): class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server.""" """End point to expose the capabilities of the server."""
PATTERNS = client_v2_patterns("/capabilities$") PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$") PATTERNS = client_patterns("/devices$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
""" """
PATTERNS = client_v2_patterns("/delete_devices") PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs): def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__() super(DeleteDevicesRestServlet, self).__init__()
@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$") PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()

View file

@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID from synapse.types import GroupID
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet): class GroupServlet(RestServlet):
"""Get the group profile """Get the group profile
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs): def __init__(self, hs):
super(GroupServlet, self).__init__() super(GroupServlet, self).__init__()
@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet): class GroupSummaryServlet(RestServlet):
"""Get the full group summary """Get the full group summary
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSummaryServlet, self).__init__() super(GroupSummaryServlet, self).__init__()
@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id - /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?" "(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)$"
@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet): class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category """Get/add/update/delete a group category
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
) )
@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet): class GroupCategoriesServlet(RestServlet):
"""Get all group categories """Get all group categories
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$" "/groups/(?P<group_id>[^/]*)/categories/$"
) )
@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet): class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role """Get/add/update/delete a group role
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
) )
@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet): class GroupRolesServlet(RestServlet):
"""Get all group roles """Get all group roles
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$" "/groups/(?P<group_id>[^/]*)/roles/$"
) )
@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id - /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?" "(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$" "/users/(?P<user_id>[^/]*)$"
@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet): class GroupRoomServlet(RestServlet):
"""Get all rooms in a group """Get all rooms in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs): def __init__(self, hs):
super(GroupRoomServlet, self).__init__() super(GroupRoomServlet, self).__init__()
@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet): class GroupUsersServlet(RestServlet):
"""Get all users in a group """Get all users in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupUsersServlet, self).__init__() super(GroupUsersServlet, self).__init__()
@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group """Get users invited to a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__() super(GroupInvitedUsersServlet, self).__init__()
@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy """Set group join policy
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__() super(GroupSettingJoinPolicyServlet, self).__init__()
@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet): class GroupCreateServlet(RestServlet):
"""Create a group """Create a group
""" """
PATTERNS = client_v2_patterns("/create_group$") PATTERNS = client_patterns("/create_group$")
def __init__(self, hs): def __init__(self, hs):
super(GroupCreateServlet, self).__init__() super(GroupCreateServlet, self).__init__()
@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group """Add a room to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
) )
@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group """Update the config of a room in a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$" "/config/(?P<config_key>[^/]*)$"
) )
@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group """Invite a user to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
) )
@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group """Kick a user from the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
) )
@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group """Leave a joined group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$" "/groups/(?P<group_id>[^/]*)/self/leave$"
) )
@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet): class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock """Attempt to join a group, or knock
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$" "/groups/(?P<group_id>[^/]*)/self/join$"
) )
@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite """Accept a group invite
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$" "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
) )
@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group """Update whether we publicise a users membership of a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$" "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
) )
@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$" "/publicised_groups/(?P<user_id>[^/]*)$"
) )
@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups$" "/publicised_groups$"
) )
@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet): class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to """Get all groups the logged in user is joined to
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/joined_groups$" "/joined_groups$"
) )

View file

@ -26,7 +26,7 @@ from synapse.http.servlet import (
) )
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
}, },
} }
""" """
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/query$") PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK 200 OK
{ "changed": ["@foo:example.com"] } { "changed": ["@foo:example.com"] }
""" """
PATTERNS = client_v2_patterns("/keys/changes$") PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/claim$") PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs): def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__() super(OneTimeKeyServlet, self).__init__()

View file

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet): class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$") PATTERNS = client_patterns("/notifications$")
def __init__(self, hs): def __init__(self, hs):
super(NotificationsServlet, self).__init__() super(NotificationsServlet, self).__init__()

View file

@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600, "expires_in": 3600,
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token" "/user/(?P<user_id>[^/]*)/openid/request_token"
) )

View file

@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet): class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs): def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__() super(ReadMarkerRestServlet, self).__init__()

View file

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet): class ReceiptRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"

View file

@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't # We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison # exist. It's a _really minor_ security flaw to use plain string comparison
@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$") PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/available") PATTERNS = client_patterns("/register/available")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$") PATTERNS = client_patterns("/register$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,12 +66,12 @@ class RelationSendServlet(RestServlet):
def register(self, http_server): def register(self, http_server):
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_v2_patterns(self.PATTERN + "$", releases=()), client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST, self.on_PUT_or_POST,
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT, self.on_PUT,
) )
@ -120,7 +120,7 @@ class RelationPaginationServlet(RestServlet):
filtered by relation type and event type. filtered by relation type and event type.
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -197,7 +197,7 @@ class RelationAggregationPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -269,7 +269,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(), releases=(),

View file

@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet): class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
) )

View file

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet): class RoomKeysServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
) )
@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version$" "/room_keys/version$"
) )
@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$" "/room_keys/version(/(?P<version>[^/]+))?$"
) )

View file

@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet):
Args: Args:
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
# /rooms/$roomid/upgrade # /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$", "/rooms/(?P<room_id>[^/]*)/upgrade$",
) )

View file

@ -21,13 +21,13 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet): class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
) )

View file

@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns("/sync$") PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs): def __init__(self, hs):

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
""" """
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
) )
@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )

View file

@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols") PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__() super(ThirdPartyProtocolsServlet, self).__init__()
@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__() super(ThirdPartyProtocolServlet, self).__init__()
@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__() super(ThirdPartyUserServlet, self).__init__()
@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__() super(ThirdPartyLocationServlet, self).__init__()

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh Exchanges refresh tokens for a pair of an access token and a new refresh
token. token.
""" """
PATTERNS = client_v2_patterns("/tokenrefresh") PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()

View file

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet): class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$") PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs): def __init__(self, hs):
""" """

View file

@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
@ -215,15 +215,7 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"])) json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss: if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items(): yield self.fetcher.get_keys(cache_misses)
try:
yield self.fetcher.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
except Exception:
logger.exception("Failed to get key for %r", server_name)
yield self.query_keys( yield self.query_keys(
request, query, query_remote_on_cache_miss=False request, query, query_remote_on_cache_miss=False
) )

View file

@ -56,8 +56,8 @@ class ThumbnailResource(Resource):
def _async_render_GET(self, request): def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width") width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height") height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png") m_type = parse_string(request, "type", "image/png")

View file

@ -36,6 +36,7 @@ from .engines import PostgresEngine
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore from .event_push_actions import EventPushActionsStore
from .events import EventsStore from .events import EventsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore from .filtering import FilteringStore
from .group_server import GroupServerStore from .group_server import GroupServerStore
from .keys import KeyStore from .keys import KeyStore
@ -66,6 +67,7 @@ logger = logging.getLogger(__name__)
class DataStore( class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore, RoomMemberStore,
RoomStore, RoomStore,
RegistrationStore, RegistrationStore,

View file

@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
import random
import sys import sys
import threading import threading
import time import time
@ -247,6 +248,8 @@ class SQLBaseStore(object):
self._check_safe_to_upsert, self._check_safe_to_upsert,
) )
self.rand = random.SystemRandom()
if self._account_validity.enabled: if self._account_validity.enabled:
self._clock.call_later( self._clock.call_later(
0.0, 0.0,
@ -308,21 +311,36 @@ class SQLBaseStore(object):
res = self.cursor_to_dict(txn) res = self.cursor_to_dict(txn)
if res: if res:
for user in res: for user in res:
self.set_expiration_date_for_user_txn(txn, user["name"]) self.set_expiration_date_for_user_txn(
txn,
user["name"],
use_delta=True,
)
yield self.runInteraction( yield self.runInteraction(
"get_users_with_no_expiration_date", "get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn, select_users_with_no_expiration_date_txn,
) )
def set_expiration_date_for_user_txn(self, txn, user_id): def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
"""Sets an expiration date to the account with the given user ID. """Sets an expiration date to the account with the given user ID.
Args: Args:
user_id (str): User ID to set an expiration date for. user_id (str): User ID to set an expiration date for.
use_delta (bool): If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.
""" """
now_ms = self._clock.time_msec() now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period expiration_ts = now_ms + self._account_validity.period
if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts - self._account_validity.startup_job_max_delta,
expiration_ts,
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"account_validity", "account_validity",
@ -1265,7 +1283,8 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues), " AND ".join("%s = ?" % (k,) for k in keyvalues),
) )
return txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
def _simple_delete_many(self, table, column, iterable, keyvalues, desc): def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction( return self.runInteraction(
@ -1284,9 +1303,12 @@ class SQLBaseStore(object):
column : column name to test for inclusion against `iterable` column : column name to test for inclusion against `iterable`
iterable : list iterable : list
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
Returns:
int: Number rows deleted
""" """
if not iterable: if not iterable:
return return 0
sql = "DELETE FROM %s" % table sql = "DELETE FROM %s" % table
@ -1301,7 +1323,9 @@ class SQLBaseStore(object):
if clauses: if clauses:
sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
return txn.execute(sql, values) txn.execute(sql, values)
return txn.rowcount
def _get_cache_dict( def _get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000 self, db_conn, table, entity_column, stream_column, max_value, limit=100000

Some files were not shown because too many files have changed in this diff Show more