Merge remote-tracking branch 'origin/develop' into joriks/opentracing_e2e

Jorik Schellekens 2019-08-05 15:47:15 +01:00
commit a68119e676
94 changed files with 1397 additions and 421 deletions


@@ -49,14 +49,15 @@ steps:
   - command:
-      - "python -m pip install tox"
+      - "apt-get update && apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev"
+      - "python3.5 -m pip install tox"
       - "tox -e py35-old,codecov"
     label: ":python: 3.5 / SQLite / Old Deps"
     env:
       TRIAL_FLAGS: "-j 2"
     plugins:
       - docker#v3.0.1:
-          image: "python:3.5"
+          image: "ubuntu:xenial"  # We use xenial to get an old sqlite and python
           propagate-environment: true
     retry:
       automatic:


@@ -1,5 +1,4 @@
-comment:
-  layout: "diff"
+comment: off
 
 coverage:
   status:


@ -1,6 +1,48 @@
Synapse 1.2.1 (2019-07-26)
==========================
Security update
---------------
This release includes *four* security fixes:
- Prevent an attack where a federated server could send redactions for arbitrary events in v1 and v2 rooms. ([\#5767](https://github.com/matrix-org/synapse/issues/5767))
- Prevent a denial-of-service attack where cycles of redaction events would make Synapse spin infinitely. Thanks to `@lrizika:matrix.org` for identifying and responsibly disclosing this issue. ([0f2ecb961](https://github.com/matrix-org/synapse/commit/0f2ecb961))
- Prevent an attack where users could be joined or parted from public rooms without their consent. Thanks to @dylangerdaly for identifying and responsibly disclosing this issue. ([\#5744](https://github.com/matrix-org/synapse/issues/5744))
- Fix a vulnerability where a federated server could spoof read-receipts from
users on other servers. Thanks to @dylangerdaly for identifying this issue too. ([\#5743](https://github.com/matrix-org/synapse/issues/5743))
Additionally, the following fix was in Synapse **1.2.0**, but was not correctly
identified during the original release:
- It was possible for a room moderator to send a redaction for an `m.room.create` event, which would downgrade the room to version 1. Thanks to `/dev/ponies` for identifying and responsibly disclosing this issue! ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Synapse 1.2.0 (2019-07-25)
==========================
No significant changes.
Synapse 1.2.0rc2 (2019-07-24)
=============================
Bugfixes
--------
- Fix a regression introduced in v1.2.0rc1 which led to incorrect labels on some prometheus metrics. ([\#5734](https://github.com/matrix-org/synapse/issues/5734))
Synapse 1.2.0rc1 (2019-07-22) Synapse 1.2.0rc1 (2019-07-22)
============================= =============================
Security fixes
--------------
This update included a security fix which was initially incorrectly flagged as
a regular bug fix.
- It was possible for a room moderator to send a redaction for an `m.room.create` event, which would downgrade the room to version 1. Thanks to `/dev/ponies` for identifying and responsibly disclosing this issue! ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Features Features
-------- --------
@ -26,7 +68,6 @@ Bugfixes
- Fix bug in #5626 that prevented the original_event field from actually having the contents of the original event in a call to `/relations`. ([\#5654](https://github.com/matrix-org/synapse/issues/5654)) - Fix bug in #5626 that prevented the original_event field from actually having the contents of the original event in a call to `/relations`. ([\#5654](https://github.com/matrix-org/synapse/issues/5654))
- Fix 3PID bind requests being sent to identity servers as `application/x-form-www-urlencoded` data, which is deprecated. ([\#5658](https://github.com/matrix-org/synapse/issues/5658)) - Fix 3PID bind requests being sent to identity servers as `application/x-form-www-urlencoded` data, which is deprecated. ([\#5658](https://github.com/matrix-org/synapse/issues/5658))
- Fix some problems with authenticating redactions in recent room versions. ([\#5699](https://github.com/matrix-org/synapse/issues/5699), [\#5700](https://github.com/matrix-org/synapse/issues/5700), [\#5707](https://github.com/matrix-org/synapse/issues/5707)) - Fix some problems with authenticating redactions in recent room versions. ([\#5699](https://github.com/matrix-org/synapse/issues/5699), [\#5700](https://github.com/matrix-org/synapse/issues/5700), [\#5707](https://github.com/matrix-org/synapse/issues/5707))
- Ignore redactions of m.room.create events. ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Updates to the Docker image Updates to the Docker image

changelog.d/5686.feature (new file)

@ -0,0 +1 @@
Use `M_USER_DEACTIVATED` instead of `M_UNKNOWN` for errcode when a deactivated user attempts to login.

changelog.d/5693.bugfix (new file)

@ -0,0 +1 @@
Fix UISIs during homeserver outage.

changelog.d/5743.bugfix (new file)

@ -0,0 +1 @@
Log when we receive an event receipt from an unexpected origin.

changelog.d/5746.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5749.misc (new file)

@ -0,0 +1 @@
Fix some error cases in the caching layer.

changelog.d/5750.misc (new file)

@ -0,0 +1 @@
Add a prometheus metric for pending cache lookups.

changelog.d/5752.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5753.misc (new file)

@ -0,0 +1 @@
Stop trying to fetch events with event_id=None.

changelog.d/5768.misc (new file)

@ -0,0 +1 @@
Convert RedactionTestCase to modern test style.

changelog.d/5770.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5774.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5775.bugfix (new file)

@ -0,0 +1 @@
Fix debian packaging scripts to correctly build sid packages.

changelog.d/5780.misc (new file)

@ -0,0 +1 @@
Allow looping calls to be given arguments.

changelog.d/5782.removal (new file)

@ -0,0 +1 @@
Remove non-functional 'expire_access_token' setting.

changelog.d/5783.feature (new file)

@ -0,0 +1 @@
Synapse can now be configured to not join remote rooms of a given "complexity" (currently, state events) over federation. This option can be used to prevent adverse performance on resource-constrained homeservers.

changelog.d/5785.misc (new file)

@ -0,0 +1 @@
Set the logs emitted when checking typing and presence timeouts to DEBUG level, not INFO.

changelog.d/5787.misc (new file)

@ -0,0 +1 @@
Remove DelayedCall debugging from the test suite, as it is no longer required in the vast majority of Synapse's tests.

changelog.d/5789.bugfix (new file)

@ -0,0 +1 @@
Fix UISIs during homeserver outage.

changelog.d/5790.misc (new file)

@ -0,0 +1 @@
Remove some spurious exceptions from the logs where we failed to talk to a remote server.

changelog.d/5792.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5793.misc (new file)

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

changelog.d/5794.misc (new file)

@ -0,0 +1 @@
Improve performance when making `.well-known` requests by sharing the SSL options between requests.

changelog.d/5796.misc (new file)

@ -0,0 +1 @@
Disable codecov GitHub comments on PRs.

changelog.d/5801.misc (new file)

@ -0,0 +1 @@
Don't allow clients to send tombstone events that reference the room it's sent in.

changelog.d/5802.misc (new file)

@ -0,0 +1 @@
Deny redactions of events sent in a different room.

changelog.d/5804.bugfix (new file)

@ -0,0 +1 @@
Fix check that tombstone is a state event in push rules.

changelog.d/5805.misc (new file)

@ -0,0 +1 @@
Deny sending well known state types as non-state events.

changelog.d/5806.bugfix (new file)

@ -0,0 +1 @@
Fix error when trying to login as a deactivated user when using a worker to handle login.

changelog.d/5807.feature (new file)

@ -0,0 +1 @@
Allow defining HTML templates to serve the user on account renewal attempt when using the account validity feature.

changelog.d/5808.misc (new file)

@ -0,0 +1 @@
Handle incorrectly encoded query params correctly by returning a 400.

changelog.d/5810.misc (new file)

@ -0,0 +1 @@
Return 502 not 500 when failing to reach any remote server.

debian/changelog

@@ -1,4 +1,10 @@
-matrix-synapse-py3 (1.1.0-1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.2.1) stable; urgency=medium
+
+  * New synapse release 1.2.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 26 Jul 2019 11:32:47 +0100
+
+matrix-synapse-py3 (1.2.0) stable; urgency=medium
 
   [ Amber Brown ]
   * Update logging config defaults to match API changes in Synapse.
@@ -6,7 +12,10 @@ matrix-synapse-py3 (1.1.0-1) UNRELEASED; urgency=medium
   [ Richard van der Hoff ]
   * Add Recommends and Depends for some libraries which you probably want.
 
- -- Erik Johnston <erikj@rae>  Thu, 04 Jul 2019 13:59:02 +0100
+  [ Synapse Packaging team ]
+  * New synapse release 1.2.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Thu, 25 Jul 2019 14:10:07 +0100
 
 matrix-synapse-py3 (1.1.0) stable; urgency=medium


@@ -42,6 +42,11 @@ RUN cd dh-virtualenv-1.1 && dpkg-buildpackage -us -uc -b
 ###
 FROM ${distro}
 
+# Get the distro we want to pull from as a dynamic build variable
+# (We need to define it in each build stage)
+ARG distro=""
+ENV distro ${distro}
+
 # Install the build dependencies
 #
 # NB: keep this list in sync with the list of build-deps in debian/control


@@ -4,7 +4,8 @@
 
 set -ex
 
-DIST=`lsb_release -c -s`
+# Get the codename from distro env
+DIST=`cut -d ':' -f2 <<< $distro`
 
 # we get a read-only copy of the source: make a writeable copy
 cp -aT /synapse/source /synapse/build


@@ -278,6 +278,23 @@ listeners:
 # Used by phonehome stats to group together related servers.
 #server_context: context
 
+# Resource-constrained Homeserver Settings
+#
+# If limit_remote_rooms.enabled is True, the room complexity will be
+# checked before a user joins a new remote room. If it is above
+# limit_remote_rooms.complexity, it will disallow joining or
+# instantly leave.
+#
+# limit_remote_rooms.complexity_error can be set to customise the text
+# displayed to the user when a room above the complexity threshold has
+# its join cancelled.
+#
+# Uncomment the below lines to enable:
+#
+#limit_remote_rooms:
+#  enabled: True
+#  complexity: 1.0
+#  complexity_error: "This room is too complex."
+
 # Whether to require a user to be in the room to add an alias to it.
 # Defaults to 'true'.
 #
@@ -785,6 +802,16 @@ uploads_path: "DATADIR/uploads"
 #   period: 6w
 #   renew_at: 1w
 #   renew_email_subject: "Renew your %(app)s account"
+#   # Directory in which Synapse will try to find the HTML files to serve to the
+#   # user when trying to renew an account. Optional, defaults to
+#   # synapse/res/templates.
+#   template_dir: "res/templates"
+#   # HTML to be displayed to the user after they successfully renewed their
+#   # account. Optional.
+#   account_renewed_html_path: "account_renewed.html"
+#   # HTML to be displayed when the user tries to renew an account with an invalid
+#   # renewal token. Optional.
+#   invalid_token_html_path: "invalid_token.html"
 
 # Time that a user's session remains valid for, after they log in.
 #
@@ -925,10 +952,6 @@ uploads_path: "DATADIR/uploads"
 #
 # macaroon_secret_key: <PRIVATE STRING>
 
-# Used to enable access token expiration.
-#
-#expire_access_token: False
-
 # a secret which is used to calculate HMACs for form values, to stop
 # falsification of values. Must be specified for the User Consent
 # forms to work.


@@ -35,4 +35,4 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.2.0rc1"
+__version__ = "1.2.1"


@@ -421,21 +421,16 @@ class Auth(object):
         try:
             user_id = self.get_user_id_from_macaroon(macaroon)
 
-            has_expiry = False
             guest = False
             for caveat in macaroon.caveats:
-                if caveat.caveat_id.startswith("time "):
-                    has_expiry = True
-                elif caveat.caveat_id == "guest = true":
+                if caveat.caveat_id == "guest = true":
                     guest = True
 
-            self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
-            )
+            self.validate_macaroon(macaroon, rights, user_id=user_id)
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
-        if not has_expiry and rights == "access":
+        if rights == "access":
             self.token_cache[token] = (user_id, guest)
 
         return user_id, guest
@@ -461,7 +456,7 @@ class Auth(object):
             return caveat.caveat_id[len(user_prefix) :]
         raise InvalidClientTokenError("No user caveat in macaroon")
 
-    def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
+    def validate_macaroon(self, macaroon, type_string, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
 
@@ -469,7 +464,6 @@ class Auth(object):
             macaroon(pymacaroons.Macaroon): The macaroon to validate
             type_string(str): The kind of token required (e.g. "access",
                 "delete_pusher")
-            verify_expiry(bool): Whether to verify whether the macaroon has expired.
             user_id (str): The user_id required
         """
         v = pymacaroons.Verifier()
@@ -482,19 +476,7 @@ class Auth(object):
         v.satisfy_exact("type = " + type_string)
         v.satisfy_exact("user_id = %s" % user_id)
         v.satisfy_exact("guest = true")
-
-        # verify_expiry should really always be True, but there exist access
-        # tokens in the wild which expire when they should not, so we can't
-        # enforce expiry yet (so we have to allow any caveat starting with
-        # 'time < ' in access tokens).
-        #
-        # On the other hand, short-term login tokens (as used by CAS login, for
-        # example) have an expiry time which we do want to enforce.
-
-        if verify_expiry:
-            v.satisfy_general(self._verify_expiry)
-        else:
-            v.satisfy_general(lambda c: c.startswith("time < "))
+        v.satisfy_general(self._verify_expiry)
 
         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))
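The caveat logic behind the trimmed-down `validate_macaroon` above can be sketched in plain Python. This is a simplified stand-in for pymacaroons' `Verifier` (which also checks the HMAC signature), not Synapse's actual code: every first-party caveat must be satisfied by at least one registered condition, either an exact match (`satisfy_exact`) or a predicate (`satisfy_general`), and after this change the expiry caveat is always checked rather than optionally skipped.

```python
def verify_caveats(caveats, exact, general):
    """Return True iff every caveat string is accepted by at least one
    condition: an exact string match or a general predicate.
    (Simplified model of pymacaroons' first-party caveat checks.)"""
    predicates = [lambda c, e=e: c == e for e in exact] + list(general)
    return all(any(p(c) for p in predicates) for c in caveats)

caveats = [
    "gen = 1",
    "type = access",
    "user_id = @alice:example.com",
    "nonce = abc123",
    "time < 2000000000000",
]
ok = verify_caveats(
    caveats,
    exact=["gen = 1", "type = access", "user_id = @alice:example.com", "guest = true"],
    general=[
        lambda c: c.startswith("nonce = "),   # any nonce value is acceptable
        lambda c: c.startswith("time < "),    # expiry is now always enforced
    ],
)
```

A caveat with no matching condition (e.g. `"evil = true"`) causes verification to fail, which is what makes macaroons attenuable but not forgeable. The real `_verify_expiry` parses the timestamp rather than matching the prefix only.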


@@ -61,6 +61,7 @@ class Codes(object):
     INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
     WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
     EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+    USER_DEACTIVATED = "M_USER_DEACTIVATED"
 
 
 class CodeMessageException(RuntimeError):
@@ -151,7 +152,7 @@ class UserDeactivatedError(SynapseError):
             msg (str): The human-readable error message
         """
         super(UserDeactivatedError, self).__init__(
-            code=http_client.FORBIDDEN, msg=msg, errcode=Codes.UNKNOWN
+            code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
        )
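A minimal stdlib-only sketch of the effect of this change (not Synapse's actual class, which subclasses `SynapseError`): a login attempt by a deactivated user now yields the dedicated `M_USER_DEACTIVATED` errcode rather than `M_UNKNOWN`, so clients can distinguish it from other 403s.

```python
import http.client

class UserDeactivatedError(Exception):
    """Sketch of the updated error: 403 with a dedicated errcode."""
    def __init__(self, msg):
        super().__init__(msg)
        self.code = http.client.FORBIDDEN  # 403
        self.errcode = "M_USER_DEACTIVATED"

err = UserDeactivatedError("This account has been deactivated")
# A client would see a response body shaped like:
# {"errcode": "M_USER_DEACTIVATED", "error": "This account has been deactivated"}
```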


@@ -116,8 +116,6 @@ class KeyConfig(Config):
                 seed = bytes(self.signing_key[0])
                 self.macaroon_secret_key = hashlib.sha256(seed).digest()
 
-        self.expire_access_token = config.get("expire_access_token", False)
-
         # a secret which is used to calculate HMACs for form values, to stop
         # falsification of values
         self.form_secret = config.get("form_secret", None)
@@ -144,10 +142,6 @@ class KeyConfig(Config):
         #
         %(macaroon_secret_key)s
 
-        # Used to enable access token expiration.
-        #
-        #expire_access_token: False
-
         # a secret which is used to calculate HMACs for form values, to stop
         # falsification of values. Must be specified for the User Consent
         # forms to work.


@@ -13,8 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from distutils.util import strtobool
 
+import pkg_resources
+
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias
 from synapse.util.stringutils import random_string_with_symbols
@@ -41,9 +44,37 @@ class AccountValidityConfig(Config):
 
             self.startup_job_max_delta = self.period * 10.0 / 100.0
 
-        if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
-            raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+        if self.renew_by_email_enabled:
+            if "public_baseurl" not in synapse_config:
+                raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+        template_dir = config.get("template_dir")
+
+        if not template_dir:
+            template_dir = pkg_resources.resource_filename("synapse", "res/templates")
+
+        if "account_renewed_html_path" in config:
+            file_path = os.path.join(template_dir, config["account_renewed_html_path"])
+
+            self.account_renewed_html_content = self.read_file(
+                file_path, "account_validity.account_renewed_html_path"
+            )
+        else:
+            self.account_renewed_html_content = (
+                "<html><body>Your account has been successfully renewed.</body></html>"
+            )
+
+        if "invalid_token_html_path" in config:
+            file_path = os.path.join(template_dir, config["invalid_token_html_path"])
+
+            self.invalid_token_html_content = self.read_file(
+                file_path, "account_validity.invalid_token_html_path"
+            )
+        else:
+            self.invalid_token_html_content = (
+                "<html><body>Invalid renewal token.</body></html>"
+            )
+
 
 class RegistrationConfig(Config):
     def read_config(self, config, **kwargs):
@@ -145,6 +176,16 @@ class RegistrationConfig(Config):
         #   period: 6w
         #   renew_at: 1w
         #   renew_email_subject: "Renew your %%(app)s account"
+        #   # Directory in which Synapse will try to find the HTML files to serve to the
+        #   # user when trying to renew an account. Optional, defaults to
+        #   # synapse/res/templates.
+        #   template_dir: "res/templates"
+        #   # HTML to be displayed to the user after they successfully renewed their
+        #   # account. Optional.
+        #   account_renewed_html_path: "account_renewed.html"
+        #   # HTML to be displayed when the user tries to renew an account with an invalid
+        #   # renewal token. Optional.
+        #   invalid_token_html_path: "invalid_token.html"
 
         # Time that a user's session remains valid for, after they log in.
         #
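The template-fallback pattern in `AccountValidityConfig` above can be sketched as a standalone helper (`resolve_renewal_html` is a hypothetical name, not Synapse's API): if a template path is configured, read the file from the template directory; otherwise fall back to a built-in HTML string.

```python
import os

DEFAULT_RENEWED_HTML = (
    "<html><body>Your account has been successfully renewed.</body></html>"
)

def resolve_renewal_html(config, template_dir, key, default_html):
    """Read the template file named by config[key] from template_dir,
    falling back to a built-in default when the key is absent."""
    if key in config:
        with open(os.path.join(template_dir, config[key])) as f:
            return f.read()
    return default_html
```

Missing the key entirely is not an error; only a configured path that points at an unreadable file would fail (Synapse's `read_file` raises a `ConfigError` in that case).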


@@ -18,6 +18,7 @@
 import logging
 import os.path
 
+import attr
 from netaddr import IPSet
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -38,6 +39,12 @@ DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
 
 DEFAULT_ROOM_VERSION = "4"
 
+ROOM_COMPLEXITY_TOO_GREAT = (
+    "Your homeserver is unable to join rooms this large or complex. "
+    "Please speak to your server administrator, or upgrade your instance "
+    "to join this room."
+)
+
 
 class ServerConfig(Config):
     def read_config(self, config, **kwargs):
@@ -247,6 +254,23 @@ class ServerConfig(Config):
 
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
 
+        @attr.s
+        class LimitRemoteRoomsConfig(object):
+            enabled = attr.ib(
+                validator=attr.validators.instance_of(bool), default=False
+            )
+            complexity = attr.ib(
+                validator=attr.validators.instance_of((int, float)), default=1.0
+            )
+            complexity_error = attr.ib(
+                validator=attr.validators.instance_of(str),
+                default=ROOM_COMPLEXITY_TOO_GREAT,
+            )
+
+        self.limit_remote_rooms = LimitRemoteRoomsConfig(
+            **config.get("limit_remote_rooms", {})
+        )
+
         bind_port = config.get("bind_port")
         if bind_port:
             if config.get("no_tls", False):
@@ -617,6 +641,23 @@ class ServerConfig(Config):
         # Used by phonehome stats to group together related servers.
         #server_context: context
 
+        # Resource-constrained Homeserver Settings
+        #
+        # If limit_remote_rooms.enabled is True, the room complexity will be
+        # checked before a user joins a new remote room. If it is above
+        # limit_remote_rooms.complexity, it will disallow joining or
+        # instantly leave.
+        #
+        # limit_remote_rooms.complexity_error can be set to customise the text
+        # displayed to the user when a room above the complexity threshold has
+        # its join cancelled.
+        #
+        # Uncomment the below lines to enable:
+        #
+        #limit_remote_rooms:
+        #  enabled: True
+        #  complexity: 1.0
+        #  complexity_error: "This room is too complex."
+
         # Whether to require a user to be in the room to add an alias to it.
         # Defaults to 'true'.
         #
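The mapping from the YAML section to the config object above can be sketched without the `attrs` dependency (`parse_limit_remote_rooms` is a hypothetical helper mirroring the defaults in the diff, not Synapse's API):

```python
ROOM_COMPLEXITY_TOO_GREAT = (
    "Your homeserver is unable to join rooms this large or complex. "
    "Please speak to your server administrator, or upgrade your instance "
    "to join this room."
)

def parse_limit_remote_rooms(config):
    """Apply the documented defaults for the limit_remote_rooms section:
    disabled, complexity threshold 1.0, and the stock error message."""
    section = config.get("limit_remote_rooms") or {}
    return {
        "enabled": bool(section.get("enabled", False)),
        "complexity": float(section.get("complexity", 1.0)),
        "complexity_error": section.get("complexity_error", ROOM_COMPLEXITY_TOO_GREAT),
    }

cfg = parse_limit_remote_rooms({"limit_remote_rooms": {"enabled": True, "complexity": 0.5}})
```

The real code uses `attr.ib` validators, so a mistyped value (e.g. `enabled: "yes"`) fails at startup rather than being silently coerced as this sketch would.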


@@ -31,6 +31,7 @@ from twisted.internet.ssl import (
     platformTrust,
 )
 from twisted.python.failure import Failure
+from twisted.web.iweb import IPolicyForHTTPS
 
 logger = logging.getLogger(__name__)
 
@@ -74,6 +75,7 @@ class ServerContextFactory(ContextFactory):
         return self._context
 
 
+@implementer(IPolicyForHTTPS)
 class ClientTLSOptionsFactory(object):
     """Factory for Twisted SSLClientConnectionCreators that are used to make connections
     to remote servers for federation.
@@ -146,6 +148,12 @@ class ClientTLSOptionsFactory(object):
             f = Failure()
             tls_protocol.failVerification(f)
 
+    def creatorForNetloc(self, hostname, port):
+        """Implements the IPolicyForHTTPS interface so that this can be passed
+        directly to agents.
+        """
+        return self.get_options(hostname)
+
 
 @implementer(IOpenSSLClientConnectionCreator)
 class SSLClientConnectionCreator(object):


@@ -95,10 +95,10 @@ class EventValidator(object):
         elif event.type == EventTypes.Topic:
             self._ensure_strings(event.content, ["topic"])
+            self._ensure_state_event(event)
 
         elif event.type == EventTypes.Name:
             self._ensure_strings(event.content, ["name"])
+            self._ensure_state_event(event)
 
         elif event.type == EventTypes.Member:
             if "membership" not in event.content:
                 raise SynapseError(400, "Content has not membership key")
@@ -106,9 +106,25 @@ class EventValidator(object):
             if event.content["membership"] not in Membership.LIST:
                 raise SynapseError(400, "Invalid membership key")
 
+            self._ensure_state_event(event)
+        elif event.type == EventTypes.Tombstone:
+            if "replacement_room" not in event.content:
+                raise SynapseError(400, "Content has no replacement_room key")
+
+            if event.content["replacement_room"] == event.room_id:
+                raise SynapseError(
+                    400, "Tombstone cannot reference the room it was sent in"
+                )
+
+            self._ensure_state_event(event)
+
     def _ensure_strings(self, d, keys):
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))
             if not isinstance(d[s], string_types):
                 raise SynapseError(400, "'%s' not a string type" % (s,))
+
+    def _ensure_state_event(self, event):
+        if not event.is_state():
+            raise SynapseError(400, "'%s' must be state events" % (event.type,))
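The new tombstone checks can be sketched as a standalone function (hypothetical helper using `ValueError` in place of Synapse's `SynapseError`): a tombstone must name a replacement room, must not point at the room it is sent in, and must be a state event.

```python
def validate_tombstone(content, room_id, is_state):
    """Sketch of the m.room.tombstone validation added above."""
    if "replacement_room" not in content:
        raise ValueError("Content has no replacement_room key")
    if content["replacement_room"] == room_id:
        raise ValueError("Tombstone cannot reference the room it was sent in")
    if not is_state:
        raise ValueError("m.room.tombstone must be a state event")

# A well-formed tombstone pointing at a different room passes:
validate_tombstone({"replacement_room": "!new:example.com"}, "!old:example.com", True)
```

The self-reference check closes off a loop where a room could "upgrade" into itself.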


@@ -511,9 +511,8 @@ class FederationClient(FederationBase):
             The [Deferred] result of callback, if it succeeds
 
         Raises:
-            SynapseError if the chosen remote server returns a 300/400 code.
-
-            RuntimeError if no servers were reachable.
+            SynapseError if the chosen remote server returns a 300/400 code, or
+            no servers were reachable.
         """
         for destination in destinations:
             if destination == self.server_name:
@@ -538,7 +537,7 @@ class FederationClient(FederationBase):
             except Exception:
                 logger.warn("Failed to %s via %s", description, destination, exc_info=1)
 
-        raise RuntimeError("Failed to %s via any server" % (description,))
+        raise SynapseError(502, "Failed to %s via any server" % (description,))
 
     def make_membership_event(
         self, destinations, room_id, user_id, membership, content, params
@@ -993,3 +992,39 @@ class FederationClient(FederationBase):
             )
 
         raise RuntimeError("Failed to send to any server.")
+
+    @defer.inlineCallbacks
+    def get_room_complexity(self, destination, room_id):
+        """
+        Fetch the complexity of a remote room from another server.
+
+        Args:
+            destination (str): The remote server
+            room_id (str): The room ID to ask about.
+
+        Returns:
+            Deferred[dict] or Deferred[None]: Dict contains the complexity
+            metric versions, while None means we could not fetch the complexity.
+        """
+        try:
+            complexity = yield self.transport_layer.get_room_complexity(
+                destination=destination, room_id=room_id
+            )
+            defer.returnValue(complexity)
+        except CodeMessageException as e:
+            # We didn't manage to get it -- probably a 404. We are okay if other
+            # servers don't give it to us.
+            logger.debug(
+                "Failed to fetch room complexity via %s for %s, got a %d",
+                destination,
+                room_id,
+                e.code,
+            )
+        except Exception:
+            logger.exception(
+                "Failed to fetch room complexity via %s for %s", destination, room_id
+            )
+
+        # If we don't manage to find it, return None. It's not an error if a
+        # server doesn't give it to us.
+        defer.returnValue(None)
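`get_room_complexity` returns either a dict of complexity metrics or `None` when the remote could not report one. How a caller might gate a join on that result can be sketched as follows (a hypothetical helper; it assumes the response carries a `"v1"` metric as in the unstable complexity API, and treats `None` as "unknown, so allow", matching the docstring above):

```python
def may_join_remote_room(complexity, limit):
    """Allow the join unless the remote reported a v1 complexity metric
    above the configured limit. None (complexity unavailable) is not an
    error and does not block the join."""
    if complexity is None:
        return True
    return complexity.get("v1", 0) <= limit
```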


@@ -365,7 +365,7 @@ class FederationServer(FederationBase):
             logger.warn("Room version %s not in %s", room_version, supported_versions)
             raise IncompatibleRoomVersionError(room_version=room_version)
 
-        pdu = yield self.handler.on_make_join_request(room_id, user_id)
+        pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
@@ -415,7 +415,7 @@ class FederationServer(FederationBase):
     def on_make_leave_request(self, origin, room_id, user_id):
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, room_id)
-        pdu = yield self.handler.on_make_leave_request(room_id, user_id)
+        pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
 
         room_version = yield self.store.get_room_version(room_id)


@@ -21,7 +21,11 @@ from six.moves import urllib
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
+from synapse.api.urls import (
+    FEDERATION_UNSTABLE_PREFIX,
+    FEDERATION_V1_PREFIX,
+    FEDERATION_V2_PREFIX,
+)
 from synapse.logging.utils import log_function
 
 logger = logging.getLogger(__name__)
@@ -935,6 +939,23 @@ class TransportLayerClient(object):
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
+    def get_room_complexity(self, destination, room_id):
+        """
+        Args:
+            destination (str): The remote server
+            room_id (str): The room ID to ask about.
+        """
+        path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
+
+        return self.client.get_json(destination=destination, path=path)
+
+
+def _create_path(federation_prefix, path, *args):
+    """
+    Ensures that all args are url encoded.
+    """
+    return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+
 
 def _create_v1_path(path, *args):
     """Creates a path against V1 federation API from the path template and
@@ -951,9 +972,7 @@ def _create_v1_path(path, *args):
     Returns:
         str
     """
-    return FEDERATION_V1_PREFIX + path % tuple(
-        urllib.parse.quote(arg, "") for arg in args
-    )
+    return _create_path(FEDERATION_V1_PREFIX, path, *args)
 
 
 def _create_v2_path(path, *args):
@@ -971,6 +990,4 @@ def _create_v2_path(path, *args):
     Returns:
         str
     """
-    return FEDERATION_V2_PREFIX + path % tuple(
-        urllib.parse.quote(arg, "") for arg in args
-    )
+    return _create_path(FEDERATION_V2_PREFIX, path, *args)


@@ -326,7 +326,9 @@ class BaseFederationServlet(object):
             if code is None:
                 continue

-            server.register_paths(method, (pattern,), self._wrap(code))
+            server.register_paths(
+                method, (pattern,), self._wrap(code), self.__class__.__name__
+            )


 class FederationSendServlet(BaseFederationServlet):


@@ -226,11 +226,19 @@ class AccountValidityHandler(object):
         Args:
             renewal_token (str): Token sent with the renewal request.
+
+        Returns:
+            bool: Whether the provided token is valid.
         """
-        user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+        try:
+            user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+        except StoreError:
+            defer.returnValue(False)
+
         logger.debug("Renewing an account for user %s", user_id)
         yield self.renew_account_for_user(user_id)

+        defer.returnValue(True)
+
     @defer.inlineCallbacks
     def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
         """Renews the account attached to a given user by pushing back the


@@ -860,7 +860,7 @@ class AuthHandler(BaseHandler):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
             user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", True, user_id)
+            auth_api.validate_macaroon(macaroon, "login", user_id)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
         self.ratelimit_login_per_account(user_id)


@@ -229,12 +229,12 @@ class DeviceHandler(DeviceWorkerHandler):
         self.federation_sender = hs.get_federation_sender()

-        self._edu_updater = DeviceListEduUpdater(hs, self)
+        self.device_list_updater = DeviceListUpdater(hs, self)

         federation_registry = hs.get_federation_registry()

         federation_registry.register_edu_handler(
-            "m.device_list_update", self._edu_updater.incoming_device_list_update
+            "m.device_list_update", self.device_list_updater.incoming_device_list_update
         )
         federation_registry.register_query_handler(
             "user_devices", self.on_federation_query_user_devices

@@ -460,7 +460,7 @@ def _update_device_from_client_ips(device, client_ips):
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})


-class DeviceListEduUpdater(object):
+class DeviceListUpdater(object):
     "Handles incoming device list updates from federation and updates the DB"

     def __init__(self, hs, device_handler):

@@ -574,85 +574,7 @@ class DeviceListUpdater(object):
             logger.debug("Need to re-sync devices for %r? %r", user_id, resync)

             if resync:
-                opentracing.log_kv({"message": "Doing resync to update device list."})
-                # Fetch all devices for the user.
-                origin = get_domain_from_id(user_id)
-                try:
-                    result = yield self.federation.query_user_devices(origin, user_id)
-                except (
-                    NotRetryingDestination,
-                    RequestSendFailed,
-                    HttpResponseException,
-                ):
-                    # TODO: Remember that we are now out of sync and try again
-                    # later
-                    logger.warn("Failed to handle device list update for %s", user_id)
-                    # We abort on exceptions rather than accepting the update
-                    # as otherwise synapse will 'forget' that its device list
-                    # is out of date. If we bail then we will retry the resync
-                    # next time we get a device list update for this user_id.
-                    # This makes it more likely that the device lists will
-                    # eventually become consistent.
-                    return
-                except FederationDeniedError as e:
-                    opentracing.set_tag("error", True)
-                    opentracing.log_kv({"reason": "FederationDeniedError"})
-                    logger.info(e)
-                    return
-                except Exception as e:
-                    # TODO: Remember that we are now out of sync and try again
-                    # later
-                    opentracing.set_tag("error", True)
-                    opentracing.log_kv(
-                        {
-                            "message": "Exception raised by federation request",
-                            "exception": e,
-                        }
-                    )
-                    logger.exception(
-                        "Failed to handle device list update for %s", user_id
-                    )
-                    return
-                opentracing.log_kv({"result": result})
-                stream_id = result["stream_id"]
-                devices = result["devices"]
-
-                # If the remote server has more than ~1000 devices for this user
-                # we assume that something is going horribly wrong (e.g. a bot
-                # that logs in and creates a new device every time it tries to
-                # send a message). Maintaining lots of devices per user in the
-                # cache can cause serious performance issues as if this request
-                # takes more than 60s to complete, internal replication from the
-                # inbound federation worker to the synapse master may time out
-                # causing the inbound federation to fail and causing the remote
-                # server to retry, causing a DoS. So in this scenario we give
-                # up on storing the total list of devices and only handle the
-                # delta instead.
-                if len(devices) > 1000:
-                    logger.warn(
-                        "Ignoring device list snapshot for %s as it has >1K devs (%d)",
-                        user_id,
-                        len(devices),
-                    )
-                    devices = []
-
-                for device in devices:
-                    logger.debug(
-                        "Handling resync update %r/%r, ID: %r",
-                        user_id,
-                        device["device_id"],
-                        stream_id,
-                    )
-
-                yield self.store.update_remote_device_list_cache(
-                    user_id, devices, stream_id
-                )
-                device_ids = [device["device_id"] for device in devices]
-                yield self.device_handler.notify_device_update(user_id, device_ids)
-
-                # We clobber the seen updates since we've re-synced from a given
-                # point.
-                self._seen_updates[user_id] = set([stream_id])
+                yield self.user_device_resync(user_id)
             else:
                 # Simply update the single device, since we know that is the only
                 # change (because of the single prev_id matching the current cache)

@@ -699,3 +621,85 @@ class DeviceListUpdater(object):
                 stream_id_in_updates.add(stream_id)

         return False
+
+    @defer.inlineCallbacks
+    def user_device_resync(self, user_id):
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id (str): The user's id whose device_list will be updated.
+        Returns:
+            Deferred[dict]: a dict with device info as under the "devices" key
+            in the result of this request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        opentracing.log_kv({"message": "Doing resync to update device list."})
+        # Fetch all devices for the user.
+        origin = get_domain_from_id(user_id)
+        try:
+            result = yield self.federation.query_user_devices(origin, user_id)
+        except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
+            # TODO: Remember that we are now out of sync and try again
+            # later
+            logger.warn("Failed to handle device list update for %s", user_id)
+            # We abort on exceptions rather than accepting the update
+            # as otherwise synapse will 'forget' that its device list
+            # is out of date. If we bail then we will retry the resync
+            # next time we get a device list update for this user_id.
+            # This makes it more likely that the device lists will
+            # eventually become consistent.
+            return
+        except FederationDeniedError as e:
+            opentracing.set_tag("error", True)
+            opentracing.log_kv({"reason": "FederationDeniedError"})
+            logger.info(e)
+            return
+        except Exception as e:
+            # TODO: Remember that we are now out of sync and try again
+            # later
+            opentracing.set_tag("error", True)
+            opentracing.log_kv(
+                {"message": "Exception raised by federation request", "exception": e}
+            )
+            logger.exception("Failed to handle device list update for %s", user_id)
+            return
+        opentracing.log_kv({"result": result})
+        stream_id = result["stream_id"]
+        devices = result["devices"]
+
+        # If the remote server has more than ~1000 devices for this user
+        # we assume that something is going horribly wrong (e.g. a bot
+        # that logs in and creates a new device every time it tries to
+        # send a message). Maintaining lots of devices per user in the
+        # cache can cause serious performance issues as if this request
+        # takes more than 60s to complete, internal replication from the
+        # inbound federation worker to the synapse master may time out
+        # causing the inbound federation to fail and causing the remote
+        # server to retry, causing a DoS. So in this scenario we give
+        # up on storing the total list of devices and only handle the
+        # delta instead.
+        if len(devices) > 1000:
+            logger.warn(
+                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
+                user_id,
+                len(devices),
+            )
+            devices = []
+
+        for device in devices:
+            logger.debug(
+                "Handling resync update %r/%r, ID: %r",
+                user_id,
+                device["device_id"],
+                stream_id,
+            )
+
+        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+        device_ids = [device["device_id"] for device in devices]
+        yield self.device_handler.notify_device_update(user_id, device_ids)
+
+        # We clobber the seen updates since we've re-synced from a given
+        # point.
+        self._seen_updates[user_id] = set([stream_id])
+
+        defer.returnValue(result)
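The >1K-device guard in `user_device_resync` drops an oversized snapshot rather than caching it, so later deltas are still processed. The decision can be isolated as a pure function (a sketch; the helper name is hypothetical, the threshold and log message come from the diff):

```python
import logging

logger = logging.getLogger(__name__)


def devices_to_cache(user_id, devices, limit=1000):
    """Return the device list to store, dropping oversized snapshots.

    Mirrors the guard in user_device_resync: if a remote server claims more
    than ~1000 devices for one user, assume something is wrong (e.g. a bot
    creating a device per message) and cache an empty snapshot instead of
    risking a replication timeout and a retry-driven DoS.
    """
    if len(devices) > limit:
        logger.warning(
            "Ignoring device list snapshot for %s as it has >1K devs (%d)",
            user_id,
            len(devices),
        )
        return []
    return devices


small = [{"device_id": str(i)} for i in range(5)]
huge = [{"device_id": str(i)} for i in range(1500)]
print(len(devices_to_cache("@alice:example.org", small)))  # -> 5
print(len(devices_to_cache("@alice:example.org", huge)))   # -> 0
```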


@@ -278,7 +278,6 @@ class DirectoryHandler(BaseHandler):
             servers = list(servers)

         return {"room_id": room_id, "servers": servers}
-        return

     @defer.inlineCallbacks
     def on_directory_query(self, args):


@@ -26,6 +26,7 @@ import synapse.logging.opentracing as opentracing
 from synapse.api.errors import CodeMessageException, SynapseError
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import UserID, get_domain_from_id
+from synapse.util import unwrapFirstError
 from synapse.util.retryutils import NotRetryingDestination

 logger = logging.getLogger(__name__)

@@ -128,9 +129,57 @@ class E2eKeysHandler(object):
         @opentracing.trace
         @defer.inlineCallbacks
         def do_remote_query(destination):
+            """Called when we are querying the device list of a user on a
+            remote homeserver and their device list is not in the device list
+            cache. If we share a room with this user and the query is not for
+            a specific device, we will update the cache with their device
+            list."""
             destination_query = remote_queries_not_in_cache[destination]
             opentracing.set_tag("key_query", destination_query)

+            # We first consider whether we wish to update the device list
+            # cache with the user's device list. We want to track a user's
+            # devices when the authenticated user shares a room with the
+            # queried user and the query has not specified a particular
+            # device.
+            # If we update the cache for the queried user we remove them from
+            # further queries. We use the more efficient batched
+            # query_client_keys for all remaining users.
+            user_ids_updated = []
+            for (user_id, device_list) in destination_query.items():
+                if user_id in user_ids_updated:
+                    continue
+                if device_list:
+                    continue
+
+                room_ids = yield self.store.get_rooms_for_user(user_id)
+                if not room_ids:
+                    continue
+
+                # We've decided we're sharing a room with this user and should
+                # probably be tracking their device lists. However, we haven't
+                # done an initial sync on the device list so we do it now.
+                try:
+                    user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+                        user_id
+                    )
+                    user_devices = user_devices["devices"]
+                    for device in user_devices:
+                        results[user_id] = {device["device_id"]: device["keys"]}
+                    user_ids_updated.append(user_id)
+                except Exception as e:
+                    failures[destination] = _exception_to_failure(e)
+
+            if len(destination_query) == len(user_ids_updated):
+                # We've updated all the users in the query and we do not need
+                # to make any further remote calls.
+                return
+
+            # Remove all the users from the query which we have updated
+            for user_id in user_ids_updated:
+                destination_query.pop(user_id)
+
             try:
                 remote_result = yield self.federation.query_client_keys(
                     destination, {"device_keys": destination_query}, timeout=timeout

@@ -153,7 +202,7 @@ class E2eKeysHandler(object):
                     for destination in remote_queries_not_in_cache
                 ],
                 consumeErrors=True,
-            )
+            ).addErrback(unwrapFirstError)
         )

         return {"device_keys": results, "failures": failures}


@@ -978,6 +978,9 @@ class FederationHandler(BaseHandler):
             except NotRetryingDestination as e:
                 logger.info(str(e))
                 continue
+            except RequestSendFailed as e:
+                logger.info("Failed to get backfill from %s because %s", dom, e)
+                continue
             except FederationDeniedError as e:
                 logger.info(e)
                 continue

@@ -1204,11 +1207,28 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
-    def on_make_join_request(self, room_id, user_id):
+    def on_make_join_request(self, origin, room_id, user_id):
         """ We've received a /make_join/ request, so we create a partial
         join event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
+
+        Args:
+            origin (str): The (verified) server name of the requesting server.
+            room_id (str): Room to create join event in
+            user_id (str): The user to create the join for
+
+        Returns:
+            Deferred[FrozenEvent]
         """
+        if get_domain_from_id(user_id) != origin:
+            logger.info(
+                "Got /make_join request for user %r from different origin %s, ignoring",
+                user_id,
+                origin,
+            )
+            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
         event_content = {"membership": Membership.JOIN}

         room_version = yield self.store.get_room_version(room_id)

@@ -1411,11 +1431,27 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
-    def on_make_leave_request(self, room_id, user_id):
+    def on_make_leave_request(self, origin, room_id, user_id):
         """ We've received a /make_leave/ request, so we create a partial
         leave event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
+
+        Args:
+            origin (str): The (verified) server name of the requesting server.
+            room_id (str): Room to create leave event in
+            user_id (str): The user to create the leave for
+
+        Returns:
+            Deferred[FrozenEvent]
         """
+        if get_domain_from_id(user_id) != origin:
+            logger.info(
+                "Got /make_leave request for user %r from different origin %s, ignoring",
+                user_id,
+                origin,
+            )
+            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
         room_version = yield self.store.get_room_version(room_id)
         builder = self.event_builder_factory.new(
             room_version,

@@ -2763,3 +2799,28 @@ class FederationHandler(BaseHandler):
             )
         else:
             return user_joined_room(self.distributor, user, room_id)
+
+    @defer.inlineCallbacks
+    def get_room_complexity(self, remote_room_hosts, room_id):
+        """
+        Fetch the complexity of a remote room over federation.
+
+        Args:
+            remote_room_hosts (list[str]): The remote servers to ask.
+            room_id (str): The room ID to ask about.
+
+        Returns:
+            Deferred[dict] or Deferred[None]: Dict contains the complexity
+            metric versions, while None means we could not fetch the complexity.
+        """
+        for host in remote_room_hosts:
+            res = yield self.federation_client.get_room_complexity(host, room_id)
+
+            # We got a result, return it.
+            if res:
+                defer.returnValue(res)
+
+        # We fell off the bottom, couldn't get the complexity from anyone. Oh
+        # well.
+        defer.returnValue(None)
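The `/make_join` and `/make_leave` fixes above both hinge on comparing the target user ID's domain against the already-verified request origin. The comparison itself is simple; a self-contained sketch (this `get_domain_from_id` is a minimal stand-in for Synapse's helper, which raises a richer error type):

```python
def get_domain_from_id(string):
    """Return the server part of a Matrix ID such as '@user:example.org'."""
    idx = string.find(":")
    if idx == -1:
        raise ValueError("Invalid ID: %r" % (string,))
    return string[idx + 1:]


def is_request_allowed(origin, user_id):
    """A server may only make join/leave requests on behalf of its own users."""
    return get_domain_from_id(user_id) == origin


print(is_request_allowed("example.org", "@alice:example.org"))   # -> True
print(is_request_allowed("evil.example", "@alice:example.org"))  # -> False
```

The origin is taken from the authenticated federation request, so a malicious server cannot forge it; rejecting mismatches closes the forced join/part vulnerability described in the changelog.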


@@ -126,9 +126,12 @@ class GroupsLocalHandler(object):
                 group_id, requester_user_id
             )
         else:
+            try:
                 res = yield self.transport_client.get_group_summary(
                     get_domain_from_id(group_id), group_id, requester_user_id
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to contact group server")

             group_server_name = get_domain_from_id(group_id)

@@ -183,9 +186,12 @@ class GroupsLocalHandler(object):
         content["user_profile"] = yield self.profile_handler.get_profile(user_id)

+        try:
             res = yield self.transport_client.create_group(
                 get_domain_from_id(group_id), group_id, user_id, content
             )
+        except RequestSendFailed:
+            raise SynapseError(502, "Failed to contact group server")

         remote_attestation = res["attestation"]
         yield self.attestations.verify_attestation(

@@ -221,9 +227,12 @@ class GroupsLocalHandler(object):
         group_server_name = get_domain_from_id(group_id)

+        try:
             res = yield self.transport_client.get_users_in_group(
                 get_domain_from_id(group_id), group_id, requester_user_id
             )
+        except RequestSendFailed:
+            raise SynapseError(502, "Failed to contact group server")

         chunk = res["chunk"]
         valid_entries = []

@@ -258,9 +267,12 @@ class GroupsLocalHandler(object):
         local_attestation = self.attestations.create_attestation(group_id, user_id)
         content["attestation"] = local_attestation

+        try:
             res = yield self.transport_client.join_group(
                 get_domain_from_id(group_id), group_id, user_id, content
             )
+        except RequestSendFailed:
+            raise SynapseError(502, "Failed to contact group server")

         remote_attestation = res["attestation"]

@@ -299,9 +311,12 @@ class GroupsLocalHandler(object):
         local_attestation = self.attestations.create_attestation(group_id, user_id)
         content["attestation"] = local_attestation

+        try:
             res = yield self.transport_client.accept_group_invite(
                 get_domain_from_id(group_id), group_id, user_id, content
             )
+        except RequestSendFailed:
+            raise SynapseError(502, "Failed to contact group server")

         remote_attestation = res["attestation"]

@@ -338,6 +353,7 @@ class GroupsLocalHandler(object):
                 group_id, user_id, requester_user_id, content
             )
         else:
+            try:
                 res = yield self.transport_client.invite_to_group(
                     get_domain_from_id(group_id),
                     group_id,

@@ -345,6 +361,8 @@ class GroupsLocalHandler(object):
                     requester_user_id,
                     content,
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to contact group server")

         return res

@@ -398,6 +416,7 @@ class GroupsLocalHandler(object):
             )
         else:
             content["requester_user_id"] = requester_user_id
+            try:
                 res = yield self.transport_client.remove_user_from_group(
                     get_domain_from_id(group_id),
                     group_id,

@@ -405,6 +424,8 @@ class GroupsLocalHandler(object):
                     user_id,
                     content,
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to contact group server")

         return res

@@ -435,9 +456,13 @@ class GroupsLocalHandler(object):
             return {"groups": result}
         else:
+            try:
                 bulk_result = yield self.transport_client.bulk_get_publicised_groups(
                     get_domain_from_id(user_id), [user_id]
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to contact group server")

             result = bulk_result.get("users", {}).get(user_id)
             # TODO: Verify attestations
             return {"groups": result}


@@ -378,7 +378,11 @@ class EventCreationHandler(object):
             # tolerate them in event_auth.check().
             prev_state_ids = yield context.get_prev_state_ids(self.store)
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
-            prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+            prev_event = (
+                yield self.store.get_event(prev_event_id, allow_none=True)
+                if prev_event_id
+                else None
+            )
             if not prev_event or prev_event.membership != Membership.JOIN:
                 logger.warning(
                     (

@@ -521,6 +525,8 @@ class EventCreationHandler(object):
         """
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
+        if not prev_event_id:
+            return
         prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
         if not prev_event:
             return

@@ -789,7 +795,6 @@ class EventCreationHandler(object):
             get_prev_content=False,
             allow_rejected=False,
             allow_none=True,
-            check_room_id=event.room_id,
         )

         # we can make some additional checks now if we have the original event.

@@ -797,6 +802,9 @@ class EventCreationHandler(object):
             if original_event.type == EventTypes.Create:
                 raise AuthError(403, "Redacting create events is not permitted")

+            if original_event.room_id != event.room_id:
+                raise SynapseError(400, "Cannot redact event from a different room")
+
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         auth_events_ids = yield self.auth.compute_auth_events(
             event, prev_state_ids, for_verification=True
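The redaction hunks above move the room check out of the storage layer (`check_room_id`) into an explicit comparison in the handler, closing the cross-room redaction attack from the changelog. The invariant in isolation (a sketch with a minimal event stand-in; the real events are `FrozenEvent` objects):

```python
from collections import namedtuple

# Minimal stand-in for a persisted event.
Event = namedtuple("Event", ["event_id", "room_id", "type"])


class SynapseError(Exception):
    """Simplified stand-in carrying an HTTP status code."""

    def __init__(self, code, msg):
        super().__init__(msg)
        self.code = code


def check_redaction_target(original_event, redaction_event):
    """A redaction must only affect events in its own room."""
    if original_event.room_id != redaction_event.room_id:
        raise SynapseError(400, "Cannot redact event from a different room")


target = Event("$abc", "!room1:example.org", "m.room.message")
redaction = Event("$def", "!room2:example.org", "m.room.redaction")
try:
    check_redaction_target(target, redaction)
except SynapseError as e:
    print(e.code)  # -> 400
```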


@@ -333,7 +333,7 @@ class PresenceHandler(object):
         """Checks the presence of users that have timed out and updates as
         appropriate.
         """
-        logger.info("Handling presence timeouts")
+        logger.debug("Handling presence timeouts")
         now = self.clock.time_msec()

         # Fetch the list of users that *may* have timed out. Things may have


@@ -17,7 +17,7 @@ import logging
 from twisted.internet import defer

 from synapse.handlers._base import BaseHandler
-from synapse.types import ReadReceipt
+from synapse.types import ReadReceipt, get_domain_from_id

 logger = logging.getLogger(__name__)

@@ -40,7 +40,19 @@ class ReceiptsHandler(BaseHandler):
     def _received_remote_receipt(self, origin, content):
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """
-        receipts = [
-            ReadReceipt(
+        receipts = []
+        for room_id, room_values in content.items():
+            for receipt_type, users in room_values.items():
+                for user_id, user_values in users.items():
+                    if get_domain_from_id(user_id) != origin:
+                        logger.info(
+                            "Received receipt for user %r from server %s, ignoring",
+                            user_id,
+                            origin,
+                        )
+                        continue
+
+                    receipts.append(
+                        ReadReceipt(
                             room_id=room_id,
                             receipt_type=receipt_type,
                             user_id=user_id,
@@ -48,10 +60,7 @@ class ReceiptsHandler(BaseHandler):
                             event_ids=user_values["event_ids"],
                             data=user_values.get("data", {}),
                         )
-            for room_id, room_values in content.items()
-            for receipt_type, users in room_values.items()
-            for user_id, user_values in users.items()
-        ]
+                    )

         yield self._handle_new_receipts(receipts)


@ -26,8 +26,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer from twisted.internet import defer
import synapse.server from synapse import types
import synapse.types
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
@ -543,7 +542,7 @@ class RoomMemberHandler(object):
), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else: else:
requester = synapse.types.create_requester(target_user) requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event( prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context event, context
@ -945,6 +944,47 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room") self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room") self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
"""
Check if complexity of a remote room is too great.
Args:
room_id (str)
remote_room_hosts (list[str])
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = yield self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
if complexity:
if complexity["v1"] > max_complexity:
return True
return False
return None
@defer.inlineCallbacks
def _is_local_room_too_complex(self, room_id):
"""
Check if the complexity of a local room is too great.
Args:
room_id (str)
Returns: bool
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = yield self.store.get_room_complexity(room_id)
if complexity["v1"] > max_complexity:
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content): def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join
@ -952,7 +992,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# filter ourselves out of remote_room_hosts: do_invite_join ignores it # filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a # and if it is the only entry we'd like to return a 404 rather than a
# 500. # 500.
remote_room_hosts = [ remote_room_hosts = [
             host for host in remote_room_hosts if host != self.hs.hostname
         ]

@@ -960,6 +999,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         if len(remote_room_hosts) == 0:
             raise SynapseError(404, "No known servers")

+        if self.hs.config.limit_remote_rooms.enabled:
+            # Fetch the room complexity
+            too_complex = yield self._is_remote_room_too_complex(
+                room_id, remote_room_hosts
+            )
+            if too_complex is True:
+                raise SynapseError(
+                    code=400,
+                    msg=self.hs.config.limit_remote_rooms.complexity_error,
+                    errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+                )
+
         # We don't do an auth check if we are doing an invite
         # join dance for now, since we're kinda implicitly checking
         # that we are allowed to join when we decide whether or not we

@@ -969,6 +1020,31 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
         yield self._user_joined_room(user, room_id)

+        # Check the room we just joined wasn't too large, if we didn't fetch the
+        # complexity of it before.
+        if self.hs.config.limit_remote_rooms.enabled:
+            if too_complex is False:
+                # We checked, and we're under the limit.
+                return
+
+            # Check again, but with the local state events
+            too_complex = yield self._is_local_room_too_complex(room_id)
+
+            if too_complex is False:
+                # We're under the limit.
+                return
+
+            # The room is too large. Leave.
+            requester = types.create_requester(user, None, False, None)
+            yield self.update_membership(
+                requester=requester, target=user, room_id=room_id, action="leave"
+            )
+
+            raise SynapseError(
+                code=400,
+                msg=self.hs.config.limit_remote_rooms.complexity_error,
+                errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+            )
+
     @defer.inlineCallbacks
     def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
         """Implements RoomMemberHandler._remote_reject_invite
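The two-phase complexity gate above (trust the remote's estimate before joining, re-check against local state afterwards, and leave if over the limit) can be sketched as a plain decision function. Everything here is illustrative; `should_leave_after_join` and its arguments are not Synapse's API.

```python
# Hypothetical sketch of the join-time complexity gate. `remote` is the remote
# server's pre-join answer (True/False, or None if it could not be fetched);
# `local` is the post-join measurement from local state events.

def should_leave_after_join(remote, local, limit_enabled):
    if not limit_enabled:
        return False
    if remote is False:
        # We checked before joining and were under the limit.
        return False
    # The remote said "too complex" or couldn't tell us; trust the local check.
    return local is True
```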

View file

@@ -83,7 +83,7 @@ class TypingHandler(object):
         self._room_typing = {}

     def _handle_timeouts(self):
-        logger.info("Checking for typing timeouts")
+        logger.debug("Checking for typing timeouts")

         now = self.clock.time_msec()

View file

@@ -64,10 +64,6 @@ class MatrixFederationAgent(object):
         tls_client_options_factory (ClientTLSOptionsFactory|None):
             factory to use for fetching client tls options, or none to disable TLS.

-        _well_known_tls_policy (IPolicyForHTTPS|None):
-            TLS policy to use for fetching .well-known files. None to use a default
-            (browser-like) implementation.
-
         _srv_resolver (SrvResolver|None):
             SRVResolver impl to use for looking up SRV records. None to use a default
             implementation.

@@ -81,7 +77,6 @@ class MatrixFederationAgent(object):
         self,
         reactor,
         tls_client_options_factory,
-        _well_known_tls_policy=None,
         _srv_resolver=None,
         _well_known_cache=well_known_cache,
     ):

@@ -98,13 +93,12 @@ class MatrixFederationAgent(object):
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60

-        agent_args = {}
-        if _well_known_tls_policy is not None:
-            # the param is called 'contextFactory', but actually passing a
-            # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
-            agent_args["contextFactory"] = _well_known_tls_policy
-
         _well_known_agent = RedirectAgent(
-            Agent(self._reactor, pool=self._pool, **agent_args)
+            Agent(
+                self._reactor,
+                pool=self._pool,
+                contextFactory=tls_client_options_factory,
+            )
         )
         self._well_known_agent = _well_known_agent

View file

@@ -245,7 +245,9 @@ class JsonResource(HttpServer, resource.Resource):
     isLeaf = True

-    _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
+    _PathEntry = collections.namedtuple(
+        "_PathEntry", ["pattern", "callback", "servlet_classname"]
+    )

     def __init__(self, hs, canonical_json=True):
         resource.Resource.__init__(self)

@@ -255,12 +257,28 @@ class JsonResource(HttpServer, resource.Resource):
         self.path_regexs = {}
         self.hs = hs

-    def register_paths(self, method, path_patterns, callback):
+    def register_paths(self, method, path_patterns, callback, servlet_classname):
+        """
+        Registers a request handler against a regular expression. Later request URLs are
+        checked against these regular expressions in order to identify an appropriate
+        handler for that request.
+
+        Args:
+            method (str): GET, POST etc
+
+            path_patterns (Iterable[str]): A list of regular expressions to which
+                the request URLs are compared.
+
+            callback (function): The handler for the request. Usually a Servlet
+
+            servlet_classname (str): The name of the handler to be used in prometheus
+                and opentracing logs.
+        """
         method = method.encode("utf-8")  # method is bytes on py3
+
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
-                self._PathEntry(path_pattern, callback)
+                self._PathEntry(path_pattern, callback, servlet_classname)
             )

     def render(self, request):

@@ -275,13 +293,9 @@ class JsonResource(HttpServer, resource.Resource):
         This checks if anyone has registered a callback for that method and
         path.
         """
-        callback, group_dict = self._get_handler_for_request(request)
+        callback, servlet_classname, group_dict = self._get_handler_for_request(request)

-        servlet_instance = getattr(callback, "__self__", None)
-        if servlet_instance is not None:
-            servlet_classname = servlet_instance.__class__.__name__
-        else:
-            servlet_classname = "%r" % callback
+        # Make sure we have a name for this handler in prometheus.
         request.request_metrics.name = servlet_classname

         # Now trigger the callback. If it returns a response, we send it

@@ -311,7 +325,8 @@ class JsonResource(HttpServer, resource.Resource):
             request (twisted.web.http.Request):

         Returns:
-            Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+            Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
+                label to use for that method in prometheus metrics, and the
                 dict mapping keys to path components as specified in the
                 handler's path match regexp.

@@ -320,7 +335,7 @@ class JsonResource(HttpServer, resource.Resource):
                 None, or a tuple of (http code, response body).
         """
         if request.method == b"OPTIONS":
-            return _options_handler, {}
+            return _options_handler, "options_request_handler", {}

         # Loop through all the registered callbacks to check if the method
         # and path regex match

@@ -328,10 +343,10 @@ class JsonResource(HttpServer, resource.Resource):
             m = path_entry.pattern.match(request.path.decode("ascii"))
             if m:
                 # We found a match!
-                return path_entry.callback, m.groupdict()
+                return path_entry.callback, path_entry.servlet_classname, m.groupdict()

         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        return _unrecognised_request_handler, {}
+        return _unrecognised_request_handler, "unrecognised_request_handler", {}

     def _send_response(
         self, request, code, response_json_object, response_code_message=None
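The change above threads an explicit `servlet_classname` through each registered path entry instead of deriving a name from the callback at request time. A minimal sketch of that registry shape (`PathRegistry` is illustrative, not Synapse's class):

```python
import collections
import re

# Each entry now carries its metrics/tracing label alongside the pattern and
# callback, so lookup can return it directly.
_PathEntry = collections.namedtuple(
    "_PathEntry", ["pattern", "callback", "servlet_classname"]
)

class PathRegistry:
    def __init__(self):
        self.path_regexs = {}

    def register_paths(self, method, path_patterns, callback, servlet_classname):
        for path_pattern in path_patterns:
            self.path_regexs.setdefault(method, []).append(
                _PathEntry(path_pattern, callback, servlet_classname)
            )

    def get_handler(self, method, path):
        for entry in self.path_regexs.get(method, []):
            m = entry.pattern.match(path)
            if m:
                return entry.callback, entry.servlet_classname, m.groupdict()
        return None, "unrecognised_request_handler", {}
```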

View file

@@ -166,7 +166,12 @@ def parse_string_from_args(
         value = args[name][0]

         if encoding:
-            value = value.decode(encoding)
+            try:
+                value = value.decode(encoding)
+            except ValueError:
+                raise SynapseError(
+                    400, "Query parameter %r must be %s" % (name, encoding)
+                )

         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (

@@ -290,11 +295,13 @@ class RestServlet(object):
         for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
             if hasattr(self, "on_%s" % (method,)):
+                servlet_classname = self.__class__.__name__
                 method_handler = getattr(self, "on_%s" % (method,))
                 http_server.register_paths(
                     method,
                     patterns,
-                    trace_servlet(self.__class__.__name__, method_handler),
+                    trace_servlet(servlet_classname, method_handler),
+                    servlet_classname,
                 )

             else:
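The hardened decode above turns an undecodable query parameter into a client error instead of an unhandled exception (note that `UnicodeDecodeError` is a subclass of `ValueError`, which is why catching `ValueError` works). A standalone sketch, with `BadRequest` standing in for Synapse's `SynapseError`:

```python
# Hypothetical stand-in for SynapseError(400, ...).
class BadRequest(Exception):
    pass

def decode_param(name, raw, encoding="ascii"):
    # bytes.decode raises UnicodeDecodeError (a ValueError) on bad input.
    try:
        return raw.decode(encoding)
    except ValueError:
        raise BadRequest("Query parameter %r must be %s" % (name, encoding))
```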

View file

@@ -245,7 +245,13 @@ BASE_APPEND_OVERRIDE_RULES = [
                 "key": "type",
                 "pattern": "m.room.tombstone",
                 "_id": "_tombstone",
-            }
+            },
+            {
+                "kind": "event_match",
+                "key": "state_key",
+                "pattern": "",
+                "_id": "_tombstone_statekey",
+            },
         ],
         "actions": ["notify", {"set_tweak": "highlight", "value": True}],
     },
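All conditions in a push rule must match for the rule to fire, so adding the `state_key: ""` condition narrows the tombstone rule to state events only. A simplified evaluator makes this visible (real `event_match` conditions use glob patterns and dotted event-field paths; this sketch uses exact top-level matching):

```python
# Simplified all-conditions-must-match evaluation of event_match conditions.
def matches(rule_conditions, event):
    return all(
        event.get(cond["key"]) == cond["pattern"]
        for cond in rule_conditions
        if cond["kind"] == "event_match"
    )

tombstone_conditions = [
    {"kind": "event_match", "key": "type", "pattern": "m.room.tombstone", "_id": "_tombstone"},
    {"kind": "event_match", "key": "state_key", "pattern": "", "_id": "_tombstone_statekey"},
]
```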

View file

@@ -205,7 +205,7 @@ class ReplicationEndpoint(object):
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))

-        http_server.register_paths(method, [pattern], handler)
+        http_server.register_paths(method, [pattern], handler, self.__class__.__name__)

     def _cached_handler(self, request, txn_id, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks

View file

@@ -0,0 +1 @@
+<html><body>Your account has been successfully renewed.</body></html>

View file

@@ -0,0 +1 @@
+<html><body>Invalid renewal token.</body></html>

View file

@@ -59,9 +59,14 @@ class SendServerNoticeServlet(RestServlet):
     def register(self, json_resource):
         PATTERN = "^/_synapse/admin/v1/send_server_notice"
-        json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST)
         json_resource.register_paths(
-            "PUT", (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), self.on_PUT
+            "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+        )
+        json_resource.register_paths(
+            "PUT",
+            (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+            self.on_PUT,
+            self.__class__.__name__,
         )

     @defer.inlineCallbacks

View file

@@ -67,11 +67,17 @@ class RoomCreateRestServlet(TransactionRestServlet):
         register_txn_path(self, PATTERNS, http_server)
         # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
         http_server.register_paths(
-            "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
+            "OPTIONS",
+            client_patterns("/rooms(?:/.*)?$", v1=True),
+            self.on_OPTIONS,
+            self.__class__.__name__,
         )
         # define CORS for /createRoom[/txnid]
         http_server.register_paths(
-            "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
+            "OPTIONS",
+            client_patterns("/createRoom(?:/.*)?$", v1=True),
+            self.on_OPTIONS,
+            self.__class__.__name__,
         )

     def on_PUT(self, request, txn_id):

@@ -116,16 +122,28 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         )
         http_server.register_paths(
-            "GET", client_patterns(state_key, v1=True), self.on_GET
+            "GET",
+            client_patterns(state_key, v1=True),
+            self.on_GET,
+            self.__class__.__name__,
         )
         http_server.register_paths(
-            "PUT", client_patterns(state_key, v1=True), self.on_PUT
+            "PUT",
+            client_patterns(state_key, v1=True),
+            self.on_PUT,
+            self.__class__.__name__,
         )
         http_server.register_paths(
-            "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
+            "GET",
+            client_patterns(no_state_key, v1=True),
+            self.on_GET_no_state_key,
+            self.__class__.__name__,
         )
         http_server.register_paths(
-            "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
+            "PUT",
+            client_patterns(no_state_key, v1=True),
+            self.on_PUT_no_state_key,
+            self.__class__.__name__,
         )

     def on_GET_no_state_key(self, request, room_id, event_type):

@@ -845,18 +863,23 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
         with_get: True to also register respective GET paths for the PUTs.
     """
     http_server.register_paths(
-        "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
+        "POST",
+        client_patterns(regex_string + "$", v1=True),
+        servlet.on_POST,
+        servlet.__class__.__name__,
     )
     http_server.register_paths(
         "PUT",
         client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
         servlet.on_PUT,
+        servlet.__class__.__name__,
     )
     if with_get:
         http_server.register_paths(
             "GET",
             client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
             servlet.on_GET,
+            servlet.__class__.__name__,
         )

View file

@@ -42,6 +42,8 @@ class AccountValidityRenewServlet(RestServlet):
         self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
+        self.success_html = hs.config.account_validity.account_renewed_html_content
+        self.failure_html = hs.config.account_validity.invalid_token_html_content

     @defer.inlineCallbacks
     def on_GET(self, request):

@@ -49,16 +51,23 @@ class AccountValidityRenewServlet(RestServlet):
             raise SynapseError(400, "Missing renewal token")
         renewal_token = request.args[b"token"][0]

-        yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
-
-        request.setResponseCode(200)
-        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-        request.setHeader(
-            b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
-        )
-        request.write(AccountValidityRenewServlet.SUCCESS_HTML)
+        token_valid = yield self.account_activity_handler.renew_account(
+            renewal_token.decode("utf8")
+        )
+
+        if token_valid:
+            status_code = 200
+            response = self.success_html
+        else:
+            status_code = 404
+            response = self.failure_html
+
+        request.setResponseCode(status_code)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(response),))
+        request.write(response.encode("utf8"))
         finish_request(request)
-        return None
+        defer.returnValue(None)


 class AccountValiditySendMailServlet(RestServlet):

@@ -87,7 +96,7 @@ class AccountValiditySendMailServlet(RestServlet):
         user_id = requester.user.to_string()
         yield self.account_activity_handler.send_renewal_email_to_user(user_id)

-        return (200, {})
+        defer.returnValue((200, {}))


 def register_servlets(hs, http_server):
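The renew endpoint above now distinguishes a valid token (200 plus the configurable success page) from an unknown one (404 plus the failure page). The branch reduces to a small pure function; the HTML strings here are stand-ins for the configured templates:

```python
# Sketch of the status/body selection added to the renew handler.
def renewal_response(token_valid, success_html, failure_html):
    if token_valid:
        return 200, success_html
    return 404, failure_html
```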

View file

@@ -72,11 +72,13 @@ class RelationSendServlet(RestServlet):
             "POST",
             client_patterns(self.PATTERN + "$", releases=()),
             self.on_PUT_or_POST,
+            self.__class__.__name__,
         )
         http_server.register_paths(
             "PUT",
             client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
             self.on_PUT,
+            self.__class__.__name__,
         )

     def on_PUT(self, request, *args, **kwargs):

View file

@@ -139,8 +139,11 @@ class EventsWorkerStore(SQLBaseStore):
             If there is a mismatch, behave as per allow_none.

         Returns:
-            Deferred : A FrozenEvent.
+            Deferred[EventBase|None]
         """
+        if not isinstance(event_id, str):
+            raise TypeError("Invalid event event_id %r" % (event_id,))
+
         events = yield self.get_events_as_list(
             [event_id],
             check_redacted=check_redacted,

@@ -268,6 +271,14 @@ class EventsWorkerStore(SQLBaseStore):
                     )
                     continue

+                if original_event.room_id != entry.event.room_id:
+                    logger.info(
+                        "Withholding redaction %s of event %s from a different room",
+                        event_id,
+                        redacted_event_id,
+                    )
+                    continue
+
                 if entry.event.internal_metadata.need_to_check_redaction():
                     original_domain = get_domain_from_id(original_event.sender)
                     redaction_domain = get_domain_from_id(entry.event.sender)

@@ -629,6 +640,10 @@ class EventsWorkerStore(SQLBaseStore):
             # we choose to ignore redactions of m.room.create events.
             return None

+        if original_ev.type == "m.room.redaction":
+            # ... and redaction events
+            return None
+
         redaction_map = yield self._get_events_from_cache_or_db(redactions)

         for redaction_id in redactions:

@@ -636,9 +651,21 @@ class EventsWorkerStore(SQLBaseStore):
             if not redaction_entry:
                 # we don't have the redaction event, or the redaction event was not
                 # authorized.
+                logger.debug(
+                    "%s was redacted by %s but redaction not found/authed",
+                    original_ev.event_id,
+                    redaction_id,
+                )
                 continue

             redaction_event = redaction_entry.event
+            if redaction_event.room_id != original_ev.room_id:
+                logger.debug(
+                    "%s was redacted by %s but redaction was in a different room!",
+                    original_ev.event_id,
+                    redaction_id,
+                )
+                continue

             # Starting in room version v3, some redactions need to be
             # rechecked if we didn't have the redacted event at the

@@ -650,8 +677,15 @@ class EventsWorkerStore(SQLBaseStore):
                     redaction_event.internal_metadata.recheck_redaction = False
                 else:
                     # Senders don't match, so the event isn't actually redacted
+                    logger.debug(
+                        "%s was redacted by %s but the senders don't match",
+                        original_ev.event_id,
+                        redaction_id,
+                    )
                     continue

+            logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id)
+
             # we found a good redaction event. Redact!
             redacted_event = prune_event(original_ev)
             redacted_event.unsigned["redacted_by"] = redaction_id
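The new same-room guard above is the heart of the redaction fix: a redaction only takes effect when it lives in the same room as the event it targets, so a federated server can no longer redact events in rooms its redaction events are not in. A simplified model of that filter (plain tuples stand in for Synapse's event objects):

```python
# Keep only redactions that share a room with the event they target.
def effective_redactions(original_room_id, redactions):
    """redactions: iterable of (redaction_id, redaction_room_id) pairs."""
    return [
        r_id
        for r_id, r_room in redactions
        if r_room == original_room_id  # withhold redactions from other rooms
    ]
```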

View file

@@ -569,6 +569,27 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_id_servers_user_bound",
         )

+    @cachedInlineCallbacks()
+    def get_user_deactivated_status(self, user_id):
+        """Retrieve the value for the `deactivated` property for the provided user.
+
+        Args:
+            user_id (str): The ID of the user to retrieve the status for.
+
+        Returns:
+            defer.Deferred(bool): The requested value.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="deactivated",
+            desc="get_user_deactivated_status",
+        )
+
+        # Convert the integer into a boolean.
+        return res == 1
+

 class RegistrationStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore

@@ -1317,24 +1338,3 @@ class RegistrationStore(
             user_id,
             deactivated,
         )
-
-    @cachedInlineCallbacks()
-    def get_user_deactivated_status(self, user_id):
-        """Retrieve the value for the `deactivated` property for the provided user.
-
-        Args:
-            user_id (str): The ID of the user to retrieve the status for.
-
-        Returns:
-            defer.Deferred(bool): The requested value.
-        """
-        res = yield self._simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="deactivated",
-            desc="get_user_deactivated_status",
-        )
-
-        # Convert the integer into a boolean.
-        return res == 1

View file

@@ -156,9 +156,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # then we can avoid a join, which is a Very Good Thing given how
             # frequently this function gets called.
             if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
                 sql = """
                     SELECT count(*), membership FROM current_state_events
                     WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
                     GROUP BY membership
                 """
             else:

@@ -179,17 +182,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # we order by membership and then fairly arbitrarily by event_id so
             # heroes are consistent
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
+                sql = """
+                    SELECT state_key, membership, event_id
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    ORDER BY
+                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        event_id ASC
+                    LIMIT ?
+                """
+            else:
                 sql = """
-                    SELECT m.user_id, m.membership, m.event_id
+                    SELECT c.state_key, m.membership, c.event_id
                     FROM room_memberships as m
-                    INNER JOIN current_state_events as c
-                    ON m.event_id = c.event_id
-                    AND m.room_id = c.room_id
-                    AND m.user_id = c.state_key
+                    INNER JOIN current_state_events as c USING (room_id, event_id)
                     WHERE c.type = 'm.room.member' AND c.room_id = ?
                     ORDER BY
                         CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        m.event_id ASC
+                        c.event_id ASC
                     LIMIT ?
                 """

@@ -256,28 +270,35 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 return invite
         return None

+    @defer.inlineCallbacks
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
         matches one in the membership list.

+        Filters out forgotten rooms.
+
         Args:
             user_id (str): The user ID.
             membership_list (list): A list of synapse.api.constants.Membership
                 values which the user must be in.

         Returns:
-            A list of dictionary objects, with room_id, membership and sender
-            defined.
+            Deferred[list[RoomsForUser]]
         """
         if not membership_list:
             return defer.succeed(None)

-        return self.runInteraction(
+        rooms = yield self.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
             user_id,
             membership_list,
         )

+        # Now we filter out forgotten rooms
+        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+        return [room for room in rooms if room.room_id not in forgotten_rooms]
+
     def _get_rooms_for_user_where_membership_is_txn(
         self, txn, user_id, membership_list
     ):

@@ -287,26 +308,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         results = []
         if membership_list:
-            where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
-                " OR ".join(["m.membership = ?" for _ in membership_list]),
-            )
-
-            args = [user_id]
-            args.extend(membership_list)
-
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
-                " FROM current_state_events as c"
-                " INNER JOIN room_memberships as m"
-                " ON m.event_id = c.event_id"
-                " AND m.room_id = c.room_id"
-                " AND m.user_id = c.state_key"
-                " INNER JOIN events as e"
-                " ON e.event_id = c.event_id"
-                " WHERE c.type = 'm.room.member' AND %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
+            if self._current_state_events_membership_up_to_date:
+                sql = """
+                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND c.membership IN (%s)
+                """ % (
+                    ",".join("?" * len(membership_list))
+                )
+            else:
+                sql = """
+                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN room_memberships AS m USING (room_id, event_id)
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND m.membership IN (%s)
+                """ % (
+                    ",".join("?" * len(membership_list))
+                )
+
+            txn.execute(sql, (user_id, *membership_list))

             results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]

         if do_invite:
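The rewritten query above builds its `IN (...)` clause by generating one `?` placeholder per membership value and binding the values as parameters. A small sketch of that pattern against an in-memory SQLite database (table and column names are illustrative, not Synapse's real schema):

```python
import sqlite3

def rooms_with_membership(conn, user_id, membership_list):
    # One "?" per value keeps everything parameterised (no SQL injection).
    sql = """
        SELECT room_id FROM memberships
        WHERE user_id = ? AND membership IN (%s)
    """ % (",".join("?" * len(membership_list)),)
    return sorted(row[0] for row in conn.execute(sql, (user_id, *membership_list)))
```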
@@ -637,6 +665,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         count = yield self.runInteraction("did_forget_membership", f)
         return count == 0

+    @cached()
+    def get_forgotten_rooms_for_user(self, user_id):
+        """Gets all rooms the user has forgotten.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]
+        """
+
+        def _get_forgotten_rooms_for_user_txn(txn):
+            # This is a slightly convoluted query that first looks up all rooms
+            # that the user has forgotten in the past, then rechecks that list
+            # to see if any have subsequently been updated. This is done so that
+            # we can use a partial index on `forgotten = 1` on the assumption
+            # that few users will actually forget many rooms.
+            #
+            # Note that a room is considered "forgotten" if *all* membership
+            # events for that user and room have the forgotten field set (as
+            # when a user forgets a room we update all rows for that user and
+            # room, not just the current one).
+            sql = """
+                SELECT room_id, (
+                    SELECT count(*) FROM room_memberships
+                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+                ) AS count
+                FROM room_memberships AS m
+                WHERE user_id = ? AND forgotten = 1
+                GROUP BY room_id, user_id;
+            """
+            txn.execute(sql, (user_id,))
+            return set(row[0] for row in txn if row[1] == 0)
+
+        return self.runInteraction(
+            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+        )
+
     @defer.inlineCallbacks
     def get_rooms_user_has_been_in(self, user_id):
         """Get all rooms that the user has ever been in.

@@ -668,6 +734,13 @@ class RoomMemberStore(RoomMemberWorkerStore):
             _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
             self._background_current_state_membership,
         )
+        self.register_background_index_update(
+            "room_membership_forgotten_idx",
+            index_name="room_memberships_user_room_forgotten",
+            table="room_memberships",
+            columns=["user_id", "room_id"],
+            where_clause="forgotten = 1",
+        )

     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.

@@ -769,6 +842,9 @@ class RoomMemberStore(RoomMemberWorkerStore):
             txn.execute(sql, (user_id, room_id))

             self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.get_forgotten_rooms_for_user, (user_id,)
+            )

         return self.runInteraction("forget_membership", f)

@@ -859,7 +935,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
         while processed < batch_size:
             txn.execute(
                 """
-                    SELECT MIN(room_id) FROM rooms WHERE room_id > ?
+                    SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
                 """,
                 (last_processed_room,),
             )

@@ -870,10 +946,10 @@ class RoomMemberStore(RoomMemberWorkerStore):
             next_room, = row

             sql = """
-                UPDATE current_state_events AS c
+                UPDATE current_state_events
                 SET membership = (
                     SELECT membership FROM room_memberships
-                    WHERE event_id = c.event_id
+                    WHERE event_id = current_state_events.event_id
                 )
                 WHERE room_id = ?
             """
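The comment in `_get_forgotten_rooms_for_user_txn` above describes the invariant the SQL enforces: a room counts as forgotten only when *every* membership row for that user and room has `forgotten = 1`. A pure-Python restatement of that aggregation (row shape `(room_id, forgotten)` is purely illustrative):

```python
# A room is forgotten iff it has at least one forgotten row and no
# non-forgotten rows, mirroring the correlated-subquery count check.
def forgotten_rooms(membership_rows):
    not_forgotten = {room for room, forgotten in membership_rows if not forgotten}
    candidates = {room for room, forgotten in membership_rows if forgotten}
    return candidates - not_forgotten
```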

View file

@@ -20,6 +20,3 @@
 -- for membership events. (Will also be null for membership events until the
 -- background update job has finished).
 ALTER TABLE current_state_events ADD membership TEXT;
-
-INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('current_state_events_membership', '{}');

View file

@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership key
+-- for membership events. (Will also be null for membership events until the
+-- background update job has finished).
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_events_membership', '{}');

View file

@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Adds an index on room_memberships for fetching all forgotten rooms for a user
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('room_membership_forgotten_idx', '{}');

View file

@@ -211,8 +211,7 @@ class StatsStore(StateDeltasStore):
         avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
         canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))

-        state_events = yield self.get_events(
-            [
+        event_ids = [
             join_rules_id,
             history_visibility_id,
             encryption_id,

@@ -221,6 +220,9 @@ class StatsStore(StateDeltasStore):
             avatar_id,
             canonical_alias_id,
         ]
+
+        state_events = yield self.get_events(
+            [ev for ev in event_ids if ev is not None]
         )

     def _get_or_none(event_id, arg):

View file

@ -59,7 +59,7 @@ class Clock(object):
"""Returns the current system time in miliseconds since epoch.""" """Returns the current system time in miliseconds since epoch."""
return int(self.time() * 1000) return int(self.time() * 1000)
def looping_call(self, f, msec): def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly. """Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time. Waits `msec` initially before calling `f` for the first time.
@ -70,8 +70,10 @@ class Clock(object):
Args: Args:
f(function): The function to call repeatedly. f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds. msec(float): How long to wait between calls in milliseconds.
*args: Positional arguments to pass to the function.
**kwargs: Keyword arguments to pass to the function.
""" """
call = task.LoopingCall(f) call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor call.clock = self._reactor
d = call.start(msec / 1000.0, now=False) d = call.start(msec / 1000.0, now=False)
d.addErrback(log_failure, "Looping call died", consumeErrors=False) d.addErrback(log_failure, "Looping call died", consumeErrors=False)
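The `looping_call` change threads `*args`/`**kwargs` straight through to `task.LoopingCall`, so callers no longer need lambda or `functools.partial` wrappers. A dependency-free sketch of the forwarding pattern (the `FakeLoopingCall` class here is a simplified stand-in, not Twisted's implementation):

```python
class FakeLoopingCall:
    """Simplified stand-in for twisted.internet.task.LoopingCall."""

    def __init__(self, f, *args, **kwargs):
        # Arguments are captured at construction and replayed on each call.
        self.f, self.args, self.kwargs = f, args, kwargs

    def fire(self):
        # In Twisted, the reactor would invoke this on each interval.
        return self.f(*self.args, **self.kwargs)


calls = []

def report(metric, value=0):
    calls.append((metric, value))

# Old API: only a zero-argument callable could be scheduled, forcing wrappers.
FakeLoopingCall(lambda: report("cache_size", value=10)).fire()
# New API: positional and keyword arguments are forwarded directly.
FakeLoopingCall(report, "cache_size", value=10).fire()

assert calls == [("cache_size", 10), ("cache_size", 10)]
```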

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
def register_cache(cache_type, cache_name, cache): def register_cache(cache_type, cache_name, cache, collect_callback=None):
"""Register a cache object for metric collection.
Args:
cache_type (str): type of the cache, e.g. "cache" or "response_cache"; used to pick the metric family
cache_name (str): name of the cache
cache (object): cache itself
collect_callback (callable|None): if not None, a function which is called during
metric collection to update additional metrics.
Returns:
CacheMetric: an object which provides inc_{hits,misses,evictions} methods
"""
# Check if the metric is already registered. Unregister it, if so. # Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are # This usually happens during tests, as at runtime these caches are
@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits) cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size) cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses) cache_total.labels(cache_name).set(self.hits + self.misses)
if collect_callback:
collect_callback()
except Exception as e: except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e) logger.warn("Error calculating metrics for %s: %s", cache_name, e)
raise raise
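The `collect_callback` hook above runs during each metrics scrape, letting a cache publish derived values lazily (elsewhere in this diff it is used to report the number of pending lookups). A minimal sketch of the mechanism, with illustrative names rather than the real `synapse.util.caches` or `prometheus_client` API:

```python
class CacheMetric:
    """Simplified metrics holder: gauges is a plain dict standing in for
    Prometheus gauges."""

    def __init__(self, cache, collect_callback=None):
        self.cache = cache
        self.hits = self.misses = 0
        self._collect_callback = collect_callback
        self.gauges = {}

    def inc_hits(self):
        self.hits += 1

    def collect(self):
        # Called on every scrape: refresh the basic counters...
        self.gauges["hits"] = self.hits
        self.gauges["total"] = self.hits + self.misses
        # ...then give the cache a chance to publish extra derived metrics.
        if self._collect_callback:
            self._collect_callback()


pending = {"k1": object()}  # lookups currently in flight

def update_pending_gauge():
    metric.gauges["pending"] = len(pending)

metric = CacheMetric(cache={}, collect_callback=update_pending_gauge)
metric.inc_hits()
metric.collect()
assert metric.gauges == {"hits": 1, "total": 1, "pending": 1}
```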

View file

@ -19,8 +19,9 @@ import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
import six from six import itervalues
from six import itervalues, string_types
from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
_CacheSentinel = object() _CacheSentinel = object()
@ -82,11 +88,19 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.thread = None self.thread = None
self.metrics = register_cache("cache", name, self.cache) self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
def _on_evicted(self, evicted_count): def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count) self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self): def check_thread(self):
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:
@ -108,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics update_metrics (bool): whether to update the cache hit rate metrics
Returns: Returns:
Either a Deferred or the raw result Either an ObservableDeferred or the raw result
""" """
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel) val = self._pending_deferred_cache.get(key, _CacheSentinel)
@ -132,9 +146,14 @@ class Cache(object):
return default return default
def set(self, key, value, callback=None): def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks) observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry: if existing_entry:
@ -142,11 +161,16 @@ class Cache(object):
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def shuffle(result): def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry: if existing_entry is entry:
self.cache.set(key, result, entry.callbacks) return True
else:
# oops, the _pending_deferred_cache has been updated since # oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date. # we started our query, so we are out of date.
# #
@ -156,6 +180,12 @@ class Cache(object):
if existing_entry is not None: if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need # we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called. # to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was # That was probably done when _pending_deferred_cache was
@ -163,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may # `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now. # not have been. Either way, let's double-check now.
entry.invalidate() entry.invalidate()
return result
entry.deferred.addCallback(shuffle) def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None): def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr) ret.addErrback(onErr)
# If our cache_key is a string on py2, try to convert to ascii result_d = cache.set(cache_key, ret, callback=invalidate_callback)
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe() observer = result_d.observe()
if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer) return make_deferred_yieldable(observer)
else:
return observer
if self.num_args == 1: if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg) missing.add(arg)
if missing: if missing:
# we need an observable deferred for each entry in the list, # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the # which we put in the cache. Each deferred resolves with the
# relevant result for that key. # relevant result for that key.
deferreds_map = {} deferreds_map = {}
@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred() deferred = defer.Deferred()
deferreds_map[arg] = deferred deferreds_map[arg] = deferred
key = arg_to_cache_key(arg) key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred) cache.set(key, deferred, callback=invalidate_callback)
cache.set(key, observable, callback=invalidate_callback)
def complete_all(res): def complete_all(res):
# the wrapped function has completed. It returns a # the wrapped function has completed. It returns a

View file

@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest from tests import unittest
@ -33,9 +37,8 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
] ]
def default_config(self, name="test"): def default_config(self, name="test"):
config = super(RoomComplexityTests, self).default_config(name=name) config = super().default_config(name=name)
config["limit_large_remote_room_joins"] = True config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
config["limit_large_remote_room_complexity"] = 0.05
return config return config
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
@ -88,3 +91,71 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code) self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23) self.assertEqual(complexity, 1.23)
def test_join_too_large(self):
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
d = handler._remote_join(
None,
["otherserver.example"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_once_joined(self):
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
# Ok, this might seem a bit weird -- I want to test that we actually
# leave the room, but I don't want to simulate two servers. So, we make
# a local room, which we say we're joining remotely, even if there's no
# remote, because we mock that out. Then, we'll leave the (actually
# local) room, which will be propagated over federation in a real
# scenario.
room_1 = self.helper.create_room_as(u1, tok=u1_token)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
600
)
d = handler._remote_join(
None,
["otherserver.example"],
room_1,
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View file

@ -44,7 +44,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["max_mau_value"] = 50 hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) hs = self.setup_test_homeserver(config=hs_config)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):

View file

@ -75,7 +75,6 @@ class MatrixFederationAgentTests(TestCase):
config_dict = default_config("test", parse=False) config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
# config_dict["trusted_key_servers"] = []
self._config = config = HomeServerConfig() self._config = config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "") config.parse_config_dict(config_dict, "", "")
@ -83,7 +82,6 @@ class MatrixFederationAgentTests(TestCase):
self.agent = MatrixFederationAgent( self.agent = MatrixFederationAgent(
reactor=self.reactor, reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(config), tls_client_options_factory=ClientTLSOptionsFactory(config),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver, _srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache, _well_known_cache=self.well_known_cache,
) )
@ -691,16 +689,18 @@ class MatrixFederationAgentTests(TestCase):
not signed by a CA not signed by a CA
""" """
# we use the same test server as the other tests, but use an agent # we use the same test server as the other tests, but use an agent with
# with _well_known_tls_policy left to the default, which will not # the config left to the default, which will not trust it (since the
# trust it (since the presented cert is signed by a test CA) # presented cert is signed by a test CA)
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
agent = MatrixFederationAgent( agent = MatrixFederationAgent(
reactor=self.reactor, reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(self._config), tls_client_options_factory=ClientTLSOptionsFactory(config),
_srv_resolver=self.mock_resolver, _srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache, _well_known_cache=self.well_known_cache,
) )

View file

@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"renew_at": 172800000, # Time in ms for 2 days "renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True, "renew_by_email_enabled": True,
"renew_email_subject": "Renew your account", "renew_email_subject": "Renew your account",
"account_renewed_html_path": "account_renewed.html",
"invalid_token_html_path": "invalid_token.html",
} }
# Email config. # Email config.
@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
content_type = None
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type = header[1]
self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
expected_html = self.hs.config.account_validity.account_renewed_html_content
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
# Move 3 days forward. If the renewal failed, every authed request with # Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should # our access token should be denied from now, otherwise they should
# succeed. # succeed.
@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
request, channel = self.make_request(b"GET", url)
self.render(request)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
content_type = None
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type = header[1]
self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
expected_html = self.hs.config.account_validity.invalid_token_html_content
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
def test_manual_email_send(self): def test_manual_email_send(self):
self.email_attempts = [] self.email_attempts = []

View file

@ -36,7 +36,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"room_name": "Server Notices", "room_name": "Server Notices",
} }
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) hs = self.setup_test_homeserver(config=hs_config)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,23 +17,21 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest
from tests.utils import create_room, setup_test_homeserver from tests.utils import create_room
class RedactionTestCase(unittest.TestCase): class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks def make_homeserver(self, reactor, clock):
def setUp(self): return self.setup_test_homeserver(
hs = yield setup_test_homeserver( resource_for_federation=Mock(), http_client=None
self.addCleanup, resource_for_federation=Mock(), http_client=None
) )
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -42,11 +41,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test") self.room1 = RoomID.from_string("!abc123:test")
yield create_room(hs, self.room1.to_string(), self.u_alice.to_string()) self.get_success(
create_room(hs, self.room1.to_string(), self.u_alice.to_string())
)
self.depth = 1 self.depth = 1
@defer.inlineCallbacks
def inject_room_member( def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={} self, room, user, membership, replaces_state=None, extra_content={}
): ):
@ -63,15 +63,14 @@ class RedactionTestCase(unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = self.get_success(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.store.persist_event(event, context) self.get_success(self.store.persist_event(event, context))
return event return event
@defer.inlineCallbacks
def inject_message(self, room, user, body): def inject_message(self, room, user, body):
self.depth += 1 self.depth += 1
@ -86,15 +85,14 @@ class RedactionTestCase(unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = self.get_success(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.store.persist_event(event, context) self.get_success(self.store.persist_event(event, context))
return event return event
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason): def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
@ -108,20 +106,21 @@ class RedactionTestCase(unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = self.get_success(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.store.persist_event(event, context) self.get_success(self.store.persist_event(event, context))
@defer.inlineCallbacks
def test_redact(self): def test_redact(self):
yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = yield self.inject_message(self.room1, self.u_alice, "t") msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
# Check event has not been redacted: # Check event has not been redacted:
event = yield self.store.get_event(msg_event.event_id) event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
@ -136,11 +135,11 @@ class RedactionTestCase(unittest.TestCase):
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
yield self.inject_redaction( self.get_success(
self.room1, msg_event.event_id, self.u_alice, reason self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
) )
event = yield self.store.get_event(msg_event.event_id) event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertEqual(msg_event.event_id, event.event_id) self.assertEqual(msg_event.event_id, event.event_id)
@ -164,15 +163,18 @@ class RedactionTestCase(unittest.TestCase):
event.unsigned["redacted_because"], event.unsigned["redacted_because"],
) )
@defer.inlineCallbacks
def test_redact_join(self): def test_redact_join(self):
yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
) )
event = yield self.store.get_event(msg_event.event_id) msg_event = self.get_success(
self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
)
event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
@ -187,13 +189,13 @@ class RedactionTestCase(unittest.TestCase):
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
yield self.inject_redaction( self.get_success(
self.room1, msg_event.event_id, self.u_alice, reason self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
) )
# Check redaction # Check redaction
event = yield self.store.get_event(msg_event.event_id) event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertTrue("redacted_because" in event.unsigned) self.assertTrue("redacted_because" in event.unsigned)

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.types import Requester, RoomID, UserID
from tests import unittest from tests import unittest
from tests.utils import create_room, setup_test_homeserver from tests.utils import create_room, setup_test_homeserver
@ -84,3 +84,38 @@ class RoomMemberStoreTestCase(unittest.TestCase):
) )
], ],
) )
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.room_creator = homeserver.get_room_creation_handler()
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
requester = Requester(user, None, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
self.get_success(
self.store._simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
"progress_json": "{}",
"depends_on": None,
},
)
)
# ... and tell the DataStore that it hasn't finished all updates yet
self.store._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
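The test above re-arms a completed background update by inserting its row back into `background_updates` and clearing the "all done" flag. A sketch of just the re-registration step, assuming a simplified table (the real schema also carries ordering and `depends_on` semantics):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE background_updates ("
    " update_name TEXT PRIMARY KEY, progress_json TEXT, depends_on TEXT)"
)
# A finished update has no row in the table; re-inserting one makes the
# background updater pick it up again from scratch on its next pass.
conn.execute(
    "INSERT INTO background_updates (update_name, progress_json, depends_on)"
    " VALUES (?, ?, ?)",
    ("current_state_events_membership", "{}", None),
)
outstanding = conn.execute(
    "SELECT update_name FROM background_updates"
).fetchall()
assert outstanding == [("current_state_events_membership",)]
```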

View file

@ -61,7 +61,10 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths( res.register_paths(
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback "GET",
[re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")],
_callback,
"test_servlet",
) )
request, channel = make_request( request, channel = make_request(
@ -82,7 +85,9 @@ class JsonResourceTests(unittest.TestCase):
raise Exception("boo") raise Exception("boo")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
@ -105,7 +110,9 @@ class JsonResourceTests(unittest.TestCase):
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
@ -122,7 +129,9 @@ class JsonResourceTests(unittest.TestCase):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
@ -143,7 +152,9 @@ class JsonResourceTests(unittest.TestCase):
self.fail("shouldn't ever get here") self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
render(request, res, self.reactor) render(request, res, self.reactor)

View file

@@ -23,8 +23,6 @@ from mock import Mock

from canonicaljson import json

-import twisted
-import twisted.logger
from twisted.internet.defer import Deferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest

@@ -80,10 +78,6 @@ class TestCase(unittest.TestCase):
        @around(self)
        def setUp(orig):
-            # enable debugging of delayed calls - this means that we get a
-            # traceback when a unit test exits leaving things on the reactor.
-            twisted.internet.base.DelayedCall.debug = True
            # if we're not starting in the sentinel logcontext, then to be honest
            # all future bets are off.
            if LoggingContext.current_context() is not LoggingContext.sentinel:

View file

@@ -27,6 +27,7 @@ from synapse.logging.context import (
    make_deferred_yieldable,
)
from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached

from tests import unittest

@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
        d2 = defer.Deferred()
        cache.set("key2", d2, partial(record_callback, 1))

-        # lookup should return the deferreds
-        self.assertIs(cache.get("key1"), d1)
-        self.assertIs(cache.get("key2"), d2)
+        # lookup should return observable deferreds
+        self.assertFalse(cache.get("key1").has_called())
+        self.assertFalse(cache.get("key2").has_called())

        # let one of the lookups complete
        d2.callback("result2")
+
+        # for now at least, the cache will return real results rather than an
+        # observabledeferred
        self.assertEqual(cache.get("key2"), "result2")

        # now do the invalidation
@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
        self.assertEqual(r, "chips")
        obj.mock.assert_not_called()

+    def test_cache_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @cached()
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
    def test_cache_logcontexts(self):
        """Check that logcontexts are set and restored correctly when
        using the cache."""

@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
                self.assertEqual(LoggingContext.current_context(), c1)

+                # the cache should now be empty
+                self.assertEqual(len(obj.fn.cache.cache), 0)
+
        obj = Cls()

        # set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
        self.assertEqual(r, "chips")
        obj.mock.assert_not_called()

+    def test_cache_iterable(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+
+        obj.mock.return_value = ["spam", "eggs"]
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = ["chips"]
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        self.assertEqual(len(obj.fn.cache.cache), 3)
+
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_not_called()
+
+    def test_cache_iterable_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+

class CachedListDescriptorTestCase(unittest.TestCase):
    @defer.inlineCallbacks
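The `len(obj.fn.cache.cache) == 3` assertion above implies that an `iterable=True` cache counts individual elements rather than entries: two cached lists of two and one elements give a total size of three. A minimal standalone sketch of that sizing rule (a hypothetical class for illustration, not Synapse's `LruCache`):

```python
class IterableSizedCache(object):
    """Toy cache whose size is the total element count across its cached
    iterables, mirroring the sizing rule the tests above rely on. This is
    an illustrative sketch, not Synapse's implementation."""

    def __init__(self):
        self.cache = {}  # key tuple -> cached iterable

    def set(self, key, value):
        self.cache[key] = value

    def __len__(self):
        # size is summed over elements, not entries
        return sum(len(v) for v in self.cache.values())


cache = IterableSizedCache()
cache.set((1, 2), ["spam", "eggs"])
cache.set((1, 3), ["chips"])
assert len(cache) == 3  # two entries, three elements
```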

View file

@@ -126,7 +126,6 @@ def default_config(name, parse=False):
        "enable_registration": True,
        "enable_registration_captcha": False,
        "macaroon_secret_key": "not even a little secret",
-        "expire_access_token": False,
        "trusted_third_party_id_servers": [],
        "room_invite_state_types": [],
        "password_providers": [],

@@ -471,7 +470,7 @@ class MockHttpResource(HttpServer):
            raise KeyError("No event can handle %s" % path)

-    def register_paths(self, method, path_patterns, callback):
+    def register_paths(self, method, path_patterns, callback, servlet_name):
        for path_pattern in path_patterns:
            self.callbacks.append((method, path_pattern, callback))
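The `register_paths` signature change running through this diff can be sketched in isolation. This is a toy mirror of the `MockHttpResource` test helper, not Synapse's real `JsonResource`; as far as this diff shows, the new `servlet_name` argument is simply threaded through by callers (the mock here ignores it, just as the helper does):

```python
import re


class MockHttpResource(object):
    """Illustrative sketch of the updated register_paths contract: every
    caller must now supply a servlet_name as a fourth argument."""

    def __init__(self):
        self.callbacks = []

    def register_paths(self, method, path_patterns, callback, servlet_name):
        # servlet_name is accepted but unused here, matching the test helper
        for path_pattern in path_patterns:
            self.callbacks.append((method, path_pattern, callback))


res = MockHttpResource()
res.register_paths(
    "GET", [re.compile("^/_matrix/foo$")], lambda request: None, "test_servlet"
)
assert len(res.callbacks) == 1
```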