Merge remote-tracking branch 'origin/develop' into joriks/opentracing_e2e

This commit is contained in:
Jorik Schellekens 2019-08-05 15:47:15 +01:00
commit a68119e676
94 changed files with 1397 additions and 421 deletions

View file

@ -49,14 +49,15 @@ steps:
- command:
- "python -m pip install tox"
- "apt-get update && apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev"
- "python3.5 -m pip install tox"
- "tox -e py35-old,codecov"
label: ":python: 3.5 / SQLite / Old Deps"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.5"
image: "ubuntu:xenial" # We use xenail to get an old sqlite and python
propagate-environment: true
retry:
automatic:

View file

@ -1,5 +1,4 @@
comment:
layout: "diff"
comment: off
coverage:
status:

View file

@ -1,6 +1,48 @@
Synapse 1.2.1 (2019-07-26)
==========================
Security update
---------------
This release includes *four* security fixes:
- Prevent an attack where a federated server could send redactions for arbitrary events in v1 and v2 rooms. ([\#5767](https://github.com/matrix-org/synapse/issues/5767))
- Prevent a denial-of-service attack where cycles of redaction events would make Synapse spin infinitely. Thanks to `@lrizika:matrix.org` for identifying and responsibly disclosing this issue. ([0f2ecb961](https://github.com/matrix-org/synapse/commit/0f2ecb961))
- Prevent an attack where users could be joined or parted from public rooms without their consent. Thanks to @dylangerdaly for identifying and responsibly disclosing this issue. ([\#5744](https://github.com/matrix-org/synapse/issues/5744))
- Fix a vulnerability where a federated server could spoof read-receipts from
users on other servers. Thanks to @dylangerdaly for identifying this issue too. ([\#5743](https://github.com/matrix-org/synapse/issues/5743))
Additionally, the following fix was in Synapse **1.2.0**, but was not correctly
identified during the original release:
- It was possible for a room moderator to send a redaction for an `m.room.create` event, which would downgrade the room to version 1. Thanks to `/dev/ponies` for identifying and responsibly disclosing this issue! ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Synapse 1.2.0 (2019-07-25)
==========================
No significant changes.
Synapse 1.2.0rc2 (2019-07-24)
=============================
Bugfixes
--------
- Fix a regression introduced in v1.2.0rc1 which led to incorrect labels on some prometheus metrics. ([\#5734](https://github.com/matrix-org/synapse/issues/5734))
Synapse 1.2.0rc1 (2019-07-22)
=============================
Security fixes
--------------
This update included a security fix which was initially incorrectly flagged as
a regular bug fix.
- It was possible for a room moderator to send a redaction for an `m.room.create` event, which would downgrade the room to version 1. Thanks to `/dev/ponies` for identifying and responsibly disclosing this issue! ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Features
--------
@ -26,7 +68,6 @@ Bugfixes
- Fix bug in #5626 that prevented the original_event field from actually having the contents of the original event in a call to `/relations`. ([\#5654](https://github.com/matrix-org/synapse/issues/5654))
- Fix 3PID bind requests being sent to identity servers as `application/x-form-www-urlencoded` data, which is deprecated. ([\#5658](https://github.com/matrix-org/synapse/issues/5658))
- Fix some problems with authenticating redactions in recent room versions. ([\#5699](https://github.com/matrix-org/synapse/issues/5699), [\#5700](https://github.com/matrix-org/synapse/issues/5700), [\#5707](https://github.com/matrix-org/synapse/issues/5707))
- Ignore redactions of m.room.create events. ([\#5701](https://github.com/matrix-org/synapse/issues/5701))
Updates to the Docker image

1
changelog.d/5686.feature Normal file
View file

@ -0,0 +1 @@
Use `M_USER_DEACTIVATED` instead of `M_UNKNOWN` for errcode when a deactivated user attempts to login.

1
changelog.d/5693.bugfix Normal file
View file

@ -0,0 +1 @@
Fix UISIs during homeserver outage.

1
changelog.d/5743.bugfix Normal file
View file

@ -0,0 +1 @@
Log when we receive an event receipt from an unexpected origin.

1
changelog.d/5746.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5749.misc Normal file
View file

@ -0,0 +1 @@
Fix some error cases in the caching layer.

1
changelog.d/5750.misc Normal file
View file

@ -0,0 +1 @@
Add a prometheus metric for pending cache lookups.

1
changelog.d/5752.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5753.misc Normal file
View file

@ -0,0 +1 @@
Stop trying to fetch events with event_id=None.

1
changelog.d/5768.misc Normal file
View file

@ -0,0 +1 @@
Convert RedactionTestCase to modern test style.

1
changelog.d/5770.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5774.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5775.bugfix Normal file
View file

@ -0,0 +1 @@
Fix debian packaging scripts to correctly build sid packages.

1
changelog.d/5780.misc Normal file
View file

@ -0,0 +1 @@
Allow looping calls to be given arguments.

1
changelog.d/5782.removal Normal file
View file

@ -0,0 +1 @@
Remove non-functional 'expire_access_token' setting.

1
changelog.d/5783.feature Normal file
View file

@ -0,0 +1 @@
Synapse can now be configured to not join remote rooms of a given "complexity" (currently, state events) over federation. This option can be used to prevent adverse performance on resource-constrained homeservers.

1
changelog.d/5785.misc Normal file
View file

@ -0,0 +1 @@
Set the logs emitted when checking typing and presence timeouts to DEBUG level, not INFO.

1
changelog.d/5787.misc Normal file
View file

@ -0,0 +1 @@
Remove DelayedCall debugging from the test suite, as it is no longer required in the vast majority of Synapse's tests.

1
changelog.d/5789.bugfix Normal file
View file

@ -0,0 +1 @@
Fix UISIs during homeserver outage.

1
changelog.d/5790.misc Normal file
View file

@ -0,0 +1 @@
Remove some spurious exceptions from the logs where we failed to talk to a remote server.

1
changelog.d/5792.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5793.misc Normal file
View file

@ -0,0 +1 @@
Reduce database IO usage by optimising queries for current membership.

1
changelog.d/5794.misc Normal file
View file

@ -0,0 +1 @@
Improve performance when making `.well-known` requests by sharing the SSL options between requests.

1
changelog.d/5796.misc Normal file
View file

@ -0,0 +1 @@
Disable codecov GitHub comments on PRs.

1
changelog.d/5801.misc Normal file
View file

@ -0,0 +1 @@
Don't allow clients to send tombstone events that reference the room it's sent in.

1
changelog.d/5802.misc Normal file
View file

@ -0,0 +1 @@
Deny redactions of events sent in a different room.

1
changelog.d/5804.bugfix Normal file
View file

@ -0,0 +1 @@
Fix check that tombstone is a state event in push rules.

1
changelog.d/5805.misc Normal file
View file

@ -0,0 +1 @@
Deny sending well known state types as non-state events.

1
changelog.d/5806.bugfix Normal file
View file

@ -0,0 +1 @@
Fix error when trying to login as a deactivated user when using a worker to handle login.

1
changelog.d/5807.feature Normal file
View file

@ -0,0 +1 @@
Allow defining HTML templates to serve the user on account renewal attempt when using the account validity feature.

1
changelog.d/5808.misc Normal file
View file

@ -0,0 +1 @@
Handle incorrectly encoded query params correctly by returning a 400.

1
changelog.d/5810.misc Normal file
View file

@ -0,0 +1 @@
Return 502 not 500 when failing to reach any remote server.

13
debian/changelog vendored
View file

@ -1,4 +1,10 @@
matrix-synapse-py3 (1.1.0-1) UNRELEASED; urgency=medium
matrix-synapse-py3 (1.2.1) stable; urgency=medium
* New synapse release 1.2.1.
-- Synapse Packaging team <packages@matrix.org> Fri, 26 Jul 2019 11:32:47 +0100
matrix-synapse-py3 (1.2.0) stable; urgency=medium
[ Amber Brown ]
* Update logging config defaults to match API changes in Synapse.
@ -6,7 +12,10 @@ matrix-synapse-py3 (1.1.0-1) UNRELEASED; urgency=medium
[ Richard van der Hoff ]
* Add Recommends and Depends for some libraries which you probably want.
-- Erik Johnston <erikj@rae> Thu, 04 Jul 2019 13:59:02 +0100
[ Synapse Packaging team ]
* New synapse release 1.2.0.
-- Synapse Packaging team <packages@matrix.org> Thu, 25 Jul 2019 14:10:07 +0100
matrix-synapse-py3 (1.1.0) stable; urgency=medium

View file

@ -42,6 +42,11 @@ RUN cd dh-virtualenv-1.1 && dpkg-buildpackage -us -uc -b
###
FROM ${distro}
# Get the distro we want to pull from as a dynamic build variable
# (We need to define it in each build stage)
ARG distro=""
ENV distro ${distro}
# Install the build dependencies
#
# NB: keep this list in sync with the list of build-deps in debian/control

View file

@ -4,7 +4,8 @@
set -ex
DIST=`lsb_release -c -s`
# Get the codename from distro env
DIST=`cut -d ':' -f2 <<< $distro`
# we get a read-only copy of the source: make a writeable copy
cp -aT /synapse/source /synapse/build

View file

@ -278,6 +278,23 @@ listeners:
# Used by phonehome stats to group together related servers.
#server_context: context
# Resource-constrained Homeserver Settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
# limit_remote_rooms.complexity, it will disallow joining or
# instantly leave.
#
# limit_remote_rooms.complexity_error can be set to customise the text
# displayed to the user when a room above the complexity threshold has
# its join cancelled.
#
# Uncomment the below lines to enable:
#limit_remote_rooms:
# enabled: True
# complexity: 1.0
# complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
#
@ -785,6 +802,16 @@ uploads_path: "DATADIR/uploads"
# period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %(app)s account"
# # Directory in which Synapse will try to find the HTML files to serve to the
# # user when trying to renew an account. Optional, defaults to
# # synapse/res/templates.
# template_dir: "res/templates"
# # HTML to be displayed to the user after they successfully renewed their
# # account. Optional.
# account_renewed_html_path: "account_renewed.html"
# # HTML to be displayed when the user tries to renew an account with an invalid
# # renewal token. Optional.
# invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
@ -925,10 +952,6 @@ uploads_path: "DATADIR/uploads"
#
# macaroon_secret_key: <PRIVATE STRING>
# Used to enable access token expiration.
#
#expire_access_token: False
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values. Must be specified for the User Consent
# forms to work.

View file

@ -35,4 +35,4 @@ try:
except ImportError:
pass
__version__ = "1.2.0rc1"
__version__ = "1.2.1"

View file

@ -421,21 +421,16 @@ class Auth(object):
try:
user_id = self.get_user_id_from_macaroon(macaroon)
has_expiry = False
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith("time "):
has_expiry = True
elif caveat.caveat_id == "guest = true":
if caveat.caveat_id == "guest = true":
guest = True
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
)
self.validate_macaroon(macaroon, rights, user_id=user_id)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise InvalidClientTokenError("Invalid macaroon passed.")
if not has_expiry and rights == "access":
if rights == "access":
self.token_cache[token] = (user_id, guest)
return user_id, guest
@ -461,7 +456,7 @@ class Auth(object):
return caveat.caveat_id[len(user_prefix) :]
raise InvalidClientTokenError("No user caveat in macaroon")
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@ -469,7 +464,6 @@ class Auth(object):
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
@ -482,19 +476,7 @@ class Auth(object):
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
# verify_expiry should really always be True, but there exist access
# tokens in the wild which expire when they should not, so we can't
# enforce expiry yet (so we have to allow any caveat starting with
# 'time < ' in access tokens).
#
# On the other hand, short-term login tokens (as used by CAS login, for
# example) have an expiry time which we do want to enforce.
if verify_expiry:
v.satisfy_general(self._verify_expiry)
else:
v.satisfy_general(lambda c: c.startswith("time < "))
v.satisfy_general(self._verify_expiry)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))

View file

@ -61,6 +61,7 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
class CodeMessageException(RuntimeError):
@ -151,7 +152,7 @@ class UserDeactivatedError(SynapseError):
msg (str): The human-readable error message
"""
super(UserDeactivatedError, self).__init__(
code=http_client.FORBIDDEN, msg=msg, errcode=Codes.UNKNOWN
code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
)

View file

@ -116,8 +116,6 @@ class KeyConfig(Config):
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()
self.expire_access_token = config.get("expire_access_token", False)
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values
self.form_secret = config.get("form_secret", None)
@ -144,10 +142,6 @@ class KeyConfig(Config):
#
%(macaroon_secret_key)s
# Used to enable access token expiration.
#
#expire_access_token: False
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values. Must be specified for the User Consent
# forms to work.

View file

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from distutils.util import strtobool
import pkg_resources
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
@ -41,8 +44,36 @@ class AccountValidityConfig(Config):
self.startup_job_max_delta = self.period * 10.0 / 100.0
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
if self.renew_by_email_enabled:
if "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
template_dir = config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates")
if "account_renewed_html_path" in config:
file_path = os.path.join(template_dir, config["account_renewed_html_path"])
self.account_renewed_html_content = self.read_file(
file_path, "account_validity.account_renewed_html_path"
)
else:
self.account_renewed_html_content = (
"<html><body>Your account has been successfully renewed.</body><html>"
)
if "invalid_token_html_path" in config:
file_path = os.path.join(template_dir, config["invalid_token_html_path"])
self.invalid_token_html_content = self.read_file(
file_path, "account_validity.invalid_token_html_path"
)
else:
self.invalid_token_html_content = (
"<html><body>Invalid renewal token.</body><html>"
)
class RegistrationConfig(Config):
@ -145,6 +176,16 @@ class RegistrationConfig(Config):
# period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
# # Directory in which Synapse will try to find the HTML files to serve to the
# # user when trying to renew an account. Optional, defaults to
# # synapse/res/templates.
# template_dir: "res/templates"
# # HTML to be displayed to the user after they successfully renewed their
# # account. Optional.
# account_renewed_html_path: "account_renewed.html"
# # HTML to be displayed when the user tries to renew an account with an invalid
# # renewal token. Optional.
# invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#

View file

@ -18,6 +18,7 @@
import logging
import os.path
import attr
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -38,6 +39,12 @@ DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
DEFAULT_ROOM_VERSION = "4"
ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. "
"Please speak to your server administrator, or upgrade your instance "
"to join this room."
)
class ServerConfig(Config):
def read_config(self, config, **kwargs):
@ -247,6 +254,23 @@ class ServerConfig(Config):
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
@attr.s
class LimitRemoteRoomsConfig(object):
enabled = attr.ib(
validator=attr.validators.instance_of(bool), default=False
)
complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0
)
complexity_error = attr.ib(
validator=attr.validators.instance_of(str),
default=ROOM_COMPLEXITY_TOO_GREAT,
)
self.limit_remote_rooms = LimitRemoteRoomsConfig(
**config.get("limit_remote_rooms", {})
)
bind_port = config.get("bind_port")
if bind_port:
if config.get("no_tls", False):
@ -617,6 +641,23 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
# Resource-constrained Homeserver Settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
# limit_remote_rooms.complexity, it will disallow joining or
# instantly leave.
#
# limit_remote_rooms.complexity_error can be set to customise the text
# displayed to the user when a room above the complexity threshold has
# its join cancelled.
#
# Uncomment the below lines to enable:
#limit_remote_rooms:
# enabled: True
# complexity: 1.0
# complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
#

View file

@ -31,6 +31,7 @@ from twisted.internet.ssl import (
platformTrust,
)
from twisted.python.failure import Failure
from twisted.web.iweb import IPolicyForHTTPS
logger = logging.getLogger(__name__)
@ -74,6 +75,7 @@ class ServerContextFactory(ContextFactory):
return self._context
@implementer(IPolicyForHTTPS)
class ClientTLSOptionsFactory(object):
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
to remote servers for federation.
@ -146,6 +148,12 @@ class ClientTLSOptionsFactory(object):
f = Failure()
tls_protocol.failVerification(f)
def creatorForNetloc(self, hostname, port):
"""Implements the IPolicyForHTTPS interace so that this can be passed
directly to agents.
"""
return self.get_options(hostname)
@implementer(IOpenSSLClientConnectionCreator)
class SSLClientConnectionCreator(object):

View file

@ -95,10 +95,10 @@ class EventValidator(object):
elif event.type == EventTypes.Topic:
self._ensure_strings(event.content, ["topic"])
self._ensure_state_event(event)
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
self._ensure_state_event(event)
elif event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
@ -106,9 +106,25 @@ class EventValidator(object):
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
self._ensure_state_event(event)
elif event.type == EventTypes.Tombstone:
if "replacement_room" not in event.content:
raise SynapseError(400, "Content has no replacement_room key")
if event.content["replacement_room"] == event.room_id:
raise SynapseError(
400, "Tombstone cannot reference the room it was sent in"
)
self._ensure_state_event(event)
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event):
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))

View file

@ -511,9 +511,8 @@ class FederationClient(FederationBase):
The [Deferred] result of callback, if it succeeds
Raises:
SynapseError if the chosen remote server returns a 300/400 code.
RuntimeError if no servers were reachable.
SynapseError if the chosen remote server returns a 300/400 code, or
no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
@ -538,7 +537,7 @@ class FederationClient(FederationBase):
except Exception:
logger.warn("Failed to %s via %s", description, destination, exc_info=1)
raise RuntimeError("Failed to %s via any server" % (description,))
raise SynapseError(502, "Failed to %s via any server" % (description,))
def make_membership_event(
self, destinations, room_id, user_id, membership, content, params
@ -993,3 +992,39 @@ class FederationClient(FederationBase):
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_room_complexity(self, destination, room_id):
"""
Fetch the complexity of a remote room from another server.
Args:
destination (str): The remote server
room_id (str): The room ID to ask about.
Returns:
Deferred[dict] or Deferred[None]: Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
"""
try:
complexity = yield self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
defer.returnValue(complexity)
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
logger.debug(
"Failed to fetch room complexity via %s for %s, got a %d",
destination,
room_id,
e.code,
)
except Exception:
logger.exception(
"Failed to fetch room complexity via %s for %s", destination, room_id
)
# If we don't manage to find it, return None. It's not an error if a
# server doesn't give it to us.
defer.returnValue(None)

View file

@ -365,7 +365,7 @@ class FederationServer(FederationBase):
logger.warn("Room version %s not in %s", room_version, supported_versions)
raise IncompatibleRoomVersionError(room_version=room_version)
pdu = yield self.handler.on_make_join_request(room_id, user_id)
pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@ -415,7 +415,7 @@ class FederationServer(FederationBase):
def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
room_version = yield self.store.get_room_version(room_id)

View file

@ -21,7 +21,11 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
from synapse.logging.utils import log_function
logger = logging.getLogger(__name__)
@ -935,6 +939,23 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content, ignore_backoff=True
)
def get_room_complexity(self, destination, room_id):
"""
Args:
destination (str): The remote server
room_id (str): The room ID to ask about.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
return self.client.get_json(destination=destination, path=path)
def _create_path(federation_prefix, path, *args):
"""
Ensures that all args are url encoded.
"""
return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
def _create_v1_path(path, *args):
"""Creates a path against V1 federation API from the path template and
@ -951,9 +972,7 @@ def _create_v1_path(path, *args):
Returns:
str
"""
return FEDERATION_V1_PREFIX + path % tuple(
urllib.parse.quote(arg, "") for arg in args
)
return _create_path(FEDERATION_V1_PREFIX, path, *args)
def _create_v2_path(path, *args):
@ -971,6 +990,4 @@ def _create_v2_path(path, *args):
Returns:
str
"""
return FEDERATION_V2_PREFIX + path % tuple(
urllib.parse.quote(arg, "") for arg in args
)
return _create_path(FEDERATION_V2_PREFIX, path, *args)

View file

@ -326,7 +326,9 @@ class BaseFederationServlet(object):
if code is None:
continue
server.register_paths(method, (pattern,), self._wrap(code))
server.register_paths(
method, (pattern,), self._wrap(code), self.__class__.__name__
)
class FederationSendServlet(BaseFederationServlet):

View file

@ -226,11 +226,19 @@ class AccountValidityHandler(object):
Args:
renewal_token (str): Token sent with the renewal request.
Returns:
bool: Whether the provided token is valid.
"""
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
defer.returnValue(False)
logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
defer.returnValue(True)
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
"""Renews the account attached to a given user by pushing back the

View file

@ -860,7 +860,7 @@ class AuthHandler(BaseHandler):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id)
auth_api.validate_macaroon(macaroon, "login", user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
self.ratelimit_login_per_account(user_id)

View file

@ -229,12 +229,12 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._edu_updater = DeviceListEduUpdater(hs, self)
self.device_list_updater = DeviceListUpdater(hs, self)
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update
"m.device_list_update", self.device_list_updater.incoming_device_list_update
)
federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices
@ -460,7 +460,7 @@ def _update_device_from_client_ips(device, client_ips):
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
class DeviceListEduUpdater(object):
class DeviceListUpdater(object):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs, device_handler):
@ -574,85 +574,7 @@ class DeviceListEduUpdater(object):
logger.debug("Need to re-sync devices for %r? %r", user_id, resync)
if resync:
opentracing.log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (
NotRetryingDestination,
RequestSendFailed,
HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
return
except FederationDeniedError as e:
opentracing.set_tag("error", True)
opentracing.log_kv({"reason": "FederationDeniedError"})
logger.info(e)
return
except Exception as e:
# TODO: Remember that we are now out of sync and try again
# later
opentracing.set_tag("error", True)
opentracing.log_kv(
{
"message": "Exception raised by federation request",
"exception": e,
}
)
logger.exception(
"Failed to handle device list update for %s", user_id
)
return
opentracing.log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
# send a message). Maintaining lots of devices per user in the
# cache can cause serious performance issues as if this request
# takes more than 60s to complete, internal replication from the
# inbound federation worker to the synapse master may time out
# causing the inbound federation to fail and causing the remote
# server to retry, causing a DoS. So in this scenario we give
# up on storing the total list of devices and only handle the
# delta instead.
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id,
len(devices),
)
devices = []
for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
user_id,
device["device_id"],
stream_id,
)
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = set([stream_id])
yield self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
@ -699,3 +621,85 @@ class DeviceListEduUpdater(object):
stream_id_in_updates.add(stream_id)
return False
@defer.inlineCallbacks
def user_device_resync(self, user_id):
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id (str): The user's id whose device_list will be updated.
Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
opentracing.log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
return
except FederationDeniedError as e:
opentracing.set_tag("error", True)
opentracing.log_kv({"reason": "FederationDeniedError"})
logger.info(e)
return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
opentracing.set_tag("error", True)
opentracing.log_kv(
{"message": "Exception raised by federation request", "exception": e}
)
logger.exception("Failed to handle device list update for %s", user_id)
return
opentracing.log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
# send a message). Maintaining lots of devices per user in the
# cache can cause serious performance issues as if this request
# takes more than 60s to complete, internal replication from the
# inbound federation worker to the synapse master may time out
# causing the inbound federation to fail and causing the remote
# server to retry, causing a DoS. So in this scenario we give
# up on storing the total list of devices and only handle the
# delta instead.
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id,
len(devices),
)
devices = []
for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
user_id,
device["device_id"],
stream_id,
)
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = set([stream_id])
defer.returnValue(result)

View file

@ -278,7 +278,6 @@ class DirectoryHandler(BaseHandler):
servers = list(servers)
return {"room_id": room_id, "servers": servers}
return
@defer.inlineCallbacks
def on_directory_query(self, args):

View file

@ -26,6 +26,7 @@ import synapse.logging.opentracing as opentracing
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@ -128,9 +129,57 @@ class E2eKeysHandler(object):
@opentracing.trace
@defer.inlineCallbacks
def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache
with their device list."""
destination_query = remote_queries_not_in_cache[destination]
opentracing.set_tag("key_query", destination_query)
# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
user_devices = yield self.device_handler.device_list_updater.user_device_resync(
user_id
)
user_devices = user_devices["devices"]
for device in user_devices:
results[user_id] = {device["device_id"]: device["keys"]}
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
try:
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
@ -153,7 +202,7 @@ class E2eKeysHandler(object):
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
)
return {"device_keys": results, "failures": failures}

View file

@ -978,6 +978,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(str(e))
continue
except RequestSendFailed as e:
logger.info("Falied to get backfill from %s because %s", dom, e)
continue
except FederationDeniedError as e:
logger.info(e)
continue
@ -1204,11 +1207,28 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_make_join_request(self, room_id, user_id):
def on_make_join_request(self, origin, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
Args:
origin (str): The (verified) server name of the requesting server.
room_id (str): Room to create join event in
user_id (str): The user to create the join for
Returns:
Deferred[FrozenEvent]
"""
if get_domain_from_id(user_id) != origin:
logger.info(
"Got /make_join request for user %r from different origin %s, ignoring",
user_id,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
event_content = {"membership": Membership.JOIN}
room_version = yield self.store.get_room_version(room_id)
@ -1411,11 +1431,27 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_make_leave_request(self, room_id, user_id):
def on_make_leave_request(self, origin, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
Args:
origin (str): The (verified) server name of the requesting server.
room_id (str): Room to create leave event in
user_id (str): The user to create the leave for
Returns:
Deferred[FrozenEvent]
"""
if get_domain_from_id(user_id) != origin:
logger.info(
"Got /make_leave request for user %r from different origin %s, ignoring",
user_id,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(
room_version,
@ -2763,3 +2799,28 @@ class FederationHandler(BaseHandler):
)
else:
return user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
"""
Fetch the complexity of a remote room over federation.
Args:
remote_room_hosts (list[str]): The remote servers to ask.
room_id (str): The room ID to ask about.
Returns:
Deferred[dict] or Deferred[None]: Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
"""
for host in remote_room_hosts:
res = yield self.federation_client.get_room_complexity(host, room_id)
# We got a result, return it.
if res:
defer.returnValue(res)
# We fell off the bottom, couldn't get the complexity from anyone. Oh
# well.
defer.returnValue(None)

View file

@ -126,9 +126,12 @@ class GroupsLocalHandler(object):
group_id, requester_user_id
)
else:
res = yield self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id
)
try:
res = yield self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
group_server_name = get_domain_from_id(group_id)
@ -183,9 +186,12 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
try:
res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
@ -221,9 +227,12 @@ class GroupsLocalHandler(object):
group_server_name = get_domain_from_id(group_id)
res = yield self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id
)
try:
res = yield self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
chunk = res["chunk"]
valid_entries = []
@ -258,9 +267,12 @@ class GroupsLocalHandler(object):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
res = yield self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content
)
try:
res = yield self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
@ -299,9 +311,12 @@ class GroupsLocalHandler(object):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
res = yield self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content
)
try:
res = yield self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
@ -338,13 +353,16 @@ class GroupsLocalHandler(object):
group_id, user_id, requester_user_id, content
)
else:
res = yield self.transport_client.invite_to_group(
get_domain_from_id(group_id),
group_id,
user_id,
requester_user_id,
content,
)
try:
res = yield self.transport_client.invite_to_group(
get_domain_from_id(group_id),
group_id,
user_id,
requester_user_id,
content,
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
return res
@ -398,13 +416,16 @@ class GroupsLocalHandler(object):
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,
user_id,
content,
)
try:
res = yield self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,
user_id,
content,
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
return res
@ -435,9 +456,13 @@ class GroupsLocalHandler(object):
return {"groups": result}
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
try:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
return {"groups": result}

View file

@ -378,7 +378,11 @@ class EventCreationHandler(object):
# tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
prev_event = (
yield self.store.get_event(prev_event_id, allow_none=True)
if prev_event_id
else None
)
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
(
@ -521,6 +525,8 @@ class EventCreationHandler(object):
"""
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@ -789,7 +795,6 @@ class EventCreationHandler(object):
get_prev_content=False,
allow_rejected=False,
allow_none=True,
check_room_id=event.room_id,
)
# we can make some additional checks now if we have the original event.
@ -797,6 +802,9 @@ class EventCreationHandler(object):
if original_event.type == EventTypes.Create:
raise AuthError(403, "Redacting create events is not permitted")
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True

View file

@ -333,7 +333,7 @@ class PresenceHandler(object):
"""Checks the presence of users that have timed out and updates as
appropriate.
"""
logger.info("Handling presence timeouts")
logger.debug("Handling presence timeouts")
now = self.clock.time_msec()
# Fetch the list of users that *may* have timed out. Things may have

View file

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt
from synapse.types import ReadReceipt, get_domain_from_id
logger = logging.getLogger(__name__)
@ -40,18 +40,27 @@ class ReceiptsHandler(BaseHandler):
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
ReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
event_ids=user_values["event_ids"],
data=user_values.get("data", {}),
)
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
receipts = []
for room_id, room_values in content.items():
for receipt_type, users in room_values.items():
for user_id, user_values in users.items():
if get_domain_from_id(user_id) != origin:
logger.info(
"Received receipt for user %r from server %s, ignoring",
user_id,
origin,
)
continue
receipts.append(
ReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
event_ids=user_values["event_ids"],
data=user_values.get("data", {}),
)
)
yield self._handle_new_receipts(receipts)

View file

@ -26,8 +26,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
import synapse.server
import synapse.types
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
from synapse.types import RoomID, UserID
@ -543,7 +542,7 @@ class RoomMemberHandler(object):
), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = synapse.types.create_requester(target_user)
requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context
@ -945,6 +944,47 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
"""
Check if complexity of a remote room is too great.
Args:
room_id (str)
remote_room_hosts (list[str])
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = yield self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
if complexity:
if complexity["v1"] > max_complexity:
return True
return False
return None
@defer.inlineCallbacks
def _is_local_room_too_complex(self, room_id):
"""
Check if the complexity of a local room is too great.
Args:
room_id (str)
Returns: bool
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = yield self.store.get_room_complexity(room_id)
if complexity["v1"] > max_complexity:
return True
return False
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
@ -952,7 +992,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a
# 500.
remote_room_hosts = [
host for host in remote_room_hosts if host != self.hs.hostname
]
@ -960,6 +999,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
raise SynapseError(
code=400,
msg=self.hs.config.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
@ -969,6 +1020,31 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
yield self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
if self.hs.config.limit_remote_rooms.enabled:
if too_complex is False:
# We checked, and we're under the limit.
return
# Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
return
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
yield self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(
code=400,
msg=self.hs.config.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
@defer.inlineCallbacks
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite

View file

@ -83,7 +83,7 @@ class TypingHandler(object):
self._room_typing = {}
def _handle_timeouts(self):
logger.info("Checking for typing timeouts")
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()

View file

@ -64,10 +64,6 @@ class MatrixFederationAgent(object):
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
_well_known_tls_policy (IPolicyForHTTPS|None):
TLS policy to use for fetching .well-known files. None to use a default
(browser-like) implementation.
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
@ -81,7 +77,6 @@ class MatrixFederationAgent(object):
self,
reactor,
tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
):
@ -98,13 +93,12 @@ class MatrixFederationAgent(object):
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
agent_args = {}
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
agent_args["contextFactory"] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
Agent(self._reactor, pool=self._pool, **agent_args)
Agent(
self._reactor,
pool=self._pool,
contextFactory=tls_client_options_factory,
)
)
self._well_known_agent = _well_known_agent

View file

@ -245,7 +245,9 @@ class JsonResource(HttpServer, resource.Resource):
isLeaf = True
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
@ -255,12 +257,28 @@ class JsonResource(HttpServer, resource.Resource):
self.path_regexs = {}
self.hs = hs
def register_paths(self, method, path_patterns, callback):
def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.
Args:
method (str): GET, POST etc
path_patterns (Iterable[str]): A list of regular expressions to which
the request URLs are compared.
callback (function): The handler for the request. Usually a Servlet
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
"""
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback)
self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
@ -275,13 +293,9 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and
path.
"""
callback, group_dict = self._get_handler_for_request(request)
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
# Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
@ -311,7 +325,8 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
@ -320,7 +335,7 @@ class JsonResource(HttpServer, resource.Resource):
None, or a tuple of (http code, response body).
"""
if request.method == b"OPTIONS":
return _options_handler, {}
return _options_handler, "options_request_handler", {}
# Loop through all the registered callbacks to check if the method
# and path regex match
@ -328,10 +343,10 @@ class JsonResource(HttpServer, resource.Resource):
m = path_entry.pattern.match(request.path.decode("ascii"))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
return path_entry.callback, path_entry.servlet_classname, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, {}
return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response(
self, request, code, response_json_object, response_code_message=None

View file

@ -166,7 +166,12 @@ def parse_string_from_args(
value = args[name][0]
if encoding:
value = value.decode(encoding)
try:
value = value.decode(encoding)
except ValueError:
raise SynapseError(
400, "Query parameter %r must be %s" % (name, encoding)
)
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
@ -290,11 +295,13 @@ class RestServlet(object):
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method,)):
servlet_classname = self.__class__.__name__
method_handler = getattr(self, "on_%s" % (method,))
http_server.register_paths(
method,
patterns,
trace_servlet(self.__class__.__name__, method_handler),
trace_servlet(servlet_classname, method_handler),
servlet_classname,
)
else:

View file

@ -245,7 +245,13 @@ BASE_APPEND_OVERRIDE_RULES = [
"key": "type",
"pattern": "m.room.tombstone",
"_id": "_tombstone",
}
},
{
"kind": "event_match",
"key": "state_key",
"pattern": "",
"_id": "_tombstone_statekey",
},
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
},

View file

@ -205,7 +205,7 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(method, [pattern], handler)
http_server.register_paths(method, [pattern], handler, self.__class__.__name__)
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks

View file

@ -0,0 +1 @@
<html><body>Your account has been successfully renewed.</body><html>

View file

@ -0,0 +1 @@
<html><body>Invalid renewal token.</body><html>

View file

@ -59,9 +59,14 @@ class SendServerNoticeServlet(RestServlet):
def register(self, json_resource):
PATTERN = "^/_synapse/admin/v1/send_server_notice"
json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST)
json_resource.register_paths(
"PUT", (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), self.on_PUT
"POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
(re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
self.on_PUT,
self.__class__.__name__,
)
@defer.inlineCallbacks

View file

@ -67,11 +67,17 @@ class RoomCreateRestServlet(TransactionRestServlet):
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths(
"OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
"OPTIONS",
client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS,
self.__class__.__name__,
)
# define CORS for /createRoom[/txnid]
http_server.register_paths(
"OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
"OPTIONS",
client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS,
self.__class__.__name__,
)
def on_PUT(self, request, txn_id):
@ -116,16 +122,28 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
http_server.register_paths(
"GET", client_patterns(state_key, v1=True), self.on_GET
"GET",
client_patterns(state_key, v1=True),
self.on_GET,
self.__class__.__name__,
)
http_server.register_paths(
"PUT", client_patterns(state_key, v1=True), self.on_PUT
"PUT",
client_patterns(state_key, v1=True),
self.on_PUT,
self.__class__.__name__,
)
http_server.register_paths(
"GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
"GET",
client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key,
self.__class__.__name__,
)
http_server.register_paths(
"PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
"PUT",
client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key,
self.__class__.__name__,
)
def on_GET_no_state_key(self, request, room_id, event_type):
@ -845,18 +863,23 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
with_get: True to also register respective GET paths for the PUTs.
"""
http_server.register_paths(
"POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
"POST",
client_patterns(regex_string + "$", v1=True),
servlet.on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT,
servlet.__class__.__name__,
)
if with_get:
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET,
servlet.__class__.__name__,
)

View file

@ -42,6 +42,8 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content
@defer.inlineCallbacks
def on_GET(self, request):
@ -49,16 +51,23 @@ class AccountValidityRenewServlet(RestServlet):
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(
b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
token_valid = yield self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
request.write(AccountValidityRenewServlet.SUCCESS_HTML)
if token_valid:
status_code = 200
response = self.success_html
else:
status_code = 404
response = self.failure_html
request.setResponseCode(status_code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8"))
finish_request(request)
return None
defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet):
@ -87,7 +96,7 @@ class AccountValiditySendMailServlet(RestServlet):
user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
return (200, {})
defer.returnValue((200, {}))
def register_servlets(hs, http_server):

View file

@ -72,11 +72,13 @@ class RelationSendServlet(RestServlet):
"POST",
client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
self.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
self.__class__.__name__,
)
def on_PUT(self, request, *args, **kwargs):

View file

@ -139,8 +139,11 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
Deferred : A FrozenEvent.
Deferred[EventBase|None]
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
@ -268,6 +271,14 @@ class EventsWorkerStore(SQLBaseStore):
)
continue
if original_event.room_id != entry.event.room_id:
logger.info(
"Withholding redaction %s of event %s from a different room",
event_id,
redacted_event_id,
)
continue
if entry.event.internal_metadata.need_to_check_redaction():
original_domain = get_domain_from_id(original_event.sender)
redaction_domain = get_domain_from_id(entry.event.sender)
@ -629,6 +640,10 @@ class EventsWorkerStore(SQLBaseStore):
# we choose to ignore redactions of m.room.create events.
return None
if original_ev.type == "m.room.redaction":
# ... and redaction events
return None
redaction_map = yield self._get_events_from_cache_or_db(redactions)
for redaction_id in redactions:
@ -636,9 +651,21 @@ class EventsWorkerStore(SQLBaseStore):
if not redaction_entry:
# we don't have the redaction event, or the redaction event was not
# authorized.
logger.debug(
"%s was redacted by %s but redaction not found/authed",
original_ev.event_id,
redaction_id,
)
continue
redaction_event = redaction_entry.event
if redaction_event.room_id != original_ev.room_id:
logger.debug(
"%s was redacted by %s but redaction was in a different room!",
original_ev.event_id,
redaction_id,
)
continue
# Starting in room version v3, some redactions need to be
# rechecked if we didn't have the redacted event at the
@ -650,8 +677,15 @@ class EventsWorkerStore(SQLBaseStore):
redaction_event.internal_metadata.recheck_redaction = False
else:
# Senders don't match, so the event isn't actually redacted
logger.debug(
"%s was redacted by %s but the senders don't match",
original_ev.event_id,
redaction_id,
)
continue
logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id)
# we found a good redaction event. Redact!
redacted_event = prune_event(original_ev)
redacted_event.unsigned["redacted_by"] = redaction_id

View file

@ -569,6 +569,27 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound",
)
@cachedInlineCallbacks()
def get_user_deactivated_status(self, user_id):
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
user_id (str): The ID of the user to retrieve the status for.
Returns:
defer.Deferred(bool): The requested value.
"""
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
desc="get_user_deactivated_status",
)
# Convert the integer into a boolean.
return res == 1
class RegistrationStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
@ -1317,24 +1338,3 @@ class RegistrationStore(
user_id,
deactivated,
)
@cachedInlineCallbacks()
def get_user_deactivated_status(self, user_id):
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
user_id (str): The ID of the user to retrieve the status for.
Returns:
defer.Deferred(bool): The requested value.
"""
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
desc="get_user_deactivated_status",
)
# Convert the integer into a boolean.
return res == 1

View file

@ -156,9 +156,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
if self._current_state_events_membership_up_to_date:
# Note, rejected events will have a null membership field, so
# we we manually filter them out.
sql = """
SELECT count(*), membership FROM current_state_events
WHERE type = 'm.room.member' AND room_id = ?
AND membership IS NOT NULL
GROUP BY membership
"""
else:
@ -179,19 +182,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
sql = """
SELECT m.user_id, m.membership, m.event_id
FROM room_memberships as m
INNER JOIN current_state_events as c
ON m.event_id = c.event_id
AND m.room_id = c.room_id
AND m.user_id = c.state_key
WHERE c.type = 'm.room.member' AND c.room_id = ?
ORDER BY
CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
m.event_id ASC
LIMIT ?
"""
if self._current_state_events_membership_up_to_date:
# Note, rejected events will have a null membership field, so
# we we manually filter them out.
sql = """
SELECT state_key, membership, event_id
FROM current_state_events
WHERE type = 'm.room.member' AND room_id = ?
AND membership IS NOT NULL
ORDER BY
CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
event_id ASC
LIMIT ?
"""
else:
sql = """
SELECT c.state_key, m.membership, c.event_id
FROM room_memberships as m
INNER JOIN current_state_events as c USING (room_id, event_id)
WHERE c.type = 'm.room.member' AND c.room_id = ?
ORDER BY
CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
c.event_id ASC
LIMIT ?
"""
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
@ -256,28 +270,35 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return invite
return None
@defer.inlineCallbacks
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
Args:
user_id (str): The user ID.
membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in.
Returns:
A list of dictionary objects, with room_id, membership and sender
defined.
Deferred[list[RoomsForUser]]
"""
if not membership_list:
return defer.succeed(None)
return self.runInteraction(
rooms = yield self.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id,
membership_list,
)
# Now we filter out forgotten rooms
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
@ -287,26 +308,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["m.membership = ?" for _ in membership_list]),
)
if self._current_state_events_membership_up_to_date:
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.membership IN (%s)
""" % (
",".join("?" * len(membership_list))
)
else:
sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND m.membership IN (%s)
""" % (
",".join("?" * len(membership_list))
)
args = [user_id]
args.extend(membership_list)
sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND %s"
) % (where_clause,)
txn.execute(sql, args)
txn.execute(sql, (user_id, *membership_list))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite:
@ -637,6 +665,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
count = yield self.runInteraction("did_forget_membership", f)
return count == 0
@cached()
def get_forgotten_rooms_for_user(self, user_id):
"""Gets all rooms the user has forgotten.
Args:
user_id (str)
Returns:
Deferred[set[str]]
"""
def _get_forgotten_rooms_for_user_txn(txn):
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
# we can use a partial index on `forgotten = 1` on the assumption
# that few users will actually forget many rooms.
#
# Note that a room is considered "forgotten" if *all* membership
# events for that user and room have the forgotten field set (as
# when a user forgets a room we update all rows for that user and
# room, not just the current one).
sql = """
SELECT room_id, (
SELECT count(*) FROM room_memberships
WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
) AS count
FROM room_memberships AS m
WHERE user_id = ? AND forgotten = 1
GROUP BY room_id, user_id;
"""
txn.execute(sql, (user_id,))
return set(row[0] for row in txn if row[1] == 0)
return self.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@defer.inlineCallbacks
def get_rooms_user_has_been_in(self, user_id):
"""Get all rooms that the user has ever been in.
@ -668,6 +734,13 @@ class RoomMemberStore(RoomMemberWorkerStore):
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
self._background_current_state_membership,
)
self.register_background_index_update(
"room_membership_forgotten_idx",
index_name="room_memberships_user_room_forgotten",
table="room_memberships",
columns=["user_id", "room_id"],
where_clause="forgotten = 1",
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
@ -769,6 +842,9 @@ class RoomMemberStore(RoomMemberWorkerStore):
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
@ -859,7 +935,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
while processed < batch_size:
txn.execute(
"""
SELECT MIN(room_id) FROM rooms WHERE room_id > ?
SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
""",
(last_processed_room,),
)
@ -870,10 +946,10 @@ class RoomMemberStore(RoomMemberWorkerStore):
next_room, = row
sql = """
UPDATE current_state_events AS c
UPDATE current_state_events
SET membership = (
SELECT membership FROM room_memberships
WHERE event_id = c.event_id
WHERE event_id = current_state_events.event_id
)
WHERE room_id = ?
"""

View file

@ -20,6 +20,3 @@
-- for membership events. (Will also be null for membership events until the
-- background update job has finished).
ALTER TABLE current_state_events ADD membership TEXT;
INSERT INTO background_updates (update_name, progress_json) VALUES
('current_state_events_membership', '{}');

View file

@ -0,0 +1,24 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We add membership to current state so that we don't need to join against
-- room_memberships, which can be surprisingly costly (we do such queries
-- very frequently).
-- This will be null for non-membership events and the content.membership key
-- for membership events. (Will also be null for membership events until the
-- background update job has finished).
INSERT INTO background_updates (update_name, progress_json) VALUES
('current_state_events_membership', '{}');

View file

@ -0,0 +1,18 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Adds an index on room_memberships for fetching all forgotten rooms for a user
INSERT INTO background_updates (update_name, progress_json) VALUES
('room_membership_forgotten_idx', '{}');

View file

@ -211,16 +211,18 @@ class StatsStore(StateDeltasStore):
avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
event_ids = [
join_rules_id,
history_visibility_id,
encryption_id,
name_id,
topic_id,
avatar_id,
canonical_alias_id,
]
state_events = yield self.get_events(
[
join_rules_id,
history_visibility_id,
encryption_id,
name_id,
topic_id,
avatar_id,
canonical_alias_id,
]
[ev for ev in event_ids if ev is not None]
)
def _get_or_none(event_id, arg):

View file

@ -59,7 +59,7 @@ class Clock(object):
"""Returns the current system time in miliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(self, f, msec):
def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@ -70,8 +70,10 @@ class Clock(object):
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
call = task.LoopingCall(f)
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
def register_cache(cache_type, cache_name, cache):
def register_cache(cache_type, cache_name, cache, collect_callback=None):
"""Register a cache object for metric collection.
Args:
cache_type (str):
cache_name (str): name of the cache
cache (object): cache itself
collect_callback (callable|None): if not None, a function which is called during
metric collection to update additional metrics.
Returns:
CacheMetric: an object which provides inc_{hits,misses,evictions} methods
"""
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
if collect_callback:
collect_callback()
except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e)
raise

View file

@ -19,8 +19,9 @@ import logging
import threading
from collections import namedtuple
import six
from six import itervalues, string_types
from six import itervalues
from prometheus_client import Gauge
from twisted.internet import defer
@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
_CacheSentinel = object()
@ -82,11 +88,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
self.metrics = register_cache("cache", name, self.cache)
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@ -108,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either a Deferred or the raw result
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@ -132,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks)
observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@ -142,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@ -163,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
# If our cache_key is a string on py2, try to convert to ascii
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer)
else:
return observer
return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
# we need an observable deferred for each entry in the list,
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a

View file

@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
@ -33,9 +37,8 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
]
def default_config(self, name="test"):
config = super(RoomComplexityTests, self).default_config(name=name)
config["limit_large_remote_room_joins"] = True
config["limit_large_remote_room_complexity"] = 0.05
config = super().default_config(name=name)
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
def prepare(self, reactor, clock, homeserver):
@ -88,3 +91,71 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
def test_join_too_large(self):
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
d = handler._remote_join(
None,
["otherserver.example"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_once_joined(self):
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
# Ok, this might seem a bit weird -- I want to test that we actually
# leave the room, but I don't want to simulate two servers. So, we make
# a local room, which we say we're joining remotely, even if there's no
# remote, because we mock that out. Then, we'll leave the (actually
# local) room, which will be propagated over federation in a real
# scenario.
room_1 = self.helper.create_room_as(u1, tok=u1_token)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
600
)
d = handler._remote_join(
None,
["otherserver.example"],
room_1,
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View file

@ -44,7 +44,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
hs = self.setup_test_homeserver(config=hs_config)
return hs
def prepare(self, reactor, clock, hs):

View file

@ -75,7 +75,6 @@ class MatrixFederationAgentTests(TestCase):
config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
# config_dict["trusted_key_servers"] = []
self._config = config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
@ -83,7 +82,6 @@ class MatrixFederationAgentTests(TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(config),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,
)
@ -691,16 +689,18 @@ class MatrixFederationAgentTests(TestCase):
not signed by a CA
"""
# we use the same test server as the other tests, but use an agent
# with _well_known_tls_policy left to the default, which will not
# trust it (since the presented cert is signed by a test CA)
# we use the same test server as the other tests, but use an agent with
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(self._config),
tls_client_options_factory=ClientTLSOptionsFactory(config),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,
)

View file

@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True,
"renew_email_subject": "Renew your account",
"account_renewed_html_path": "account_renewed.html",
"invalid_token_html_path": "invalid_token.html",
}
# Email config.
@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
content_type = None
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type = header[1]
self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
expected_html = self.hs.config.account_validity.account_renewed_html_content
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
# Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should
# succeed.
@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
request, channel = self.make_request(b"GET", url)
self.render(request)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
content_type = None
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type = header[1]
self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
expected_html = self.hs.config.account_validity.invalid_token_html_content
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
def test_manual_email_send(self):
self.email_attempts = []

View file

@ -36,7 +36,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"room_name": "Server Notices",
}
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
hs = self.setup_test_homeserver(config=hs_config)
return hs
def prepare(self, reactor, clock, hs):

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,23 +17,21 @@
from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
from tests.utils import create_room, setup_test_homeserver
from tests.utils import create_room
class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
self.addCleanup, resource_for_federation=Mock(), http_client=None
class RedactionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
resource_for_federation=Mock(), http_client=None
)
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@ -42,11 +41,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
self.get_success(
create_room(hs, self.room1.to_string(), self.u_alice.to_string())
)
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={}
):
@ -63,15 +63,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.store.persist_event(event, context)
self.get_success(self.store.persist_event(event, context))
return event
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
@ -86,15 +85,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.store.persist_event(event, context)
self.get_success(self.store.persist_event(event, context))
return event
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@ -108,20 +106,21 @@ class RedactionTestCase(unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.store.persist_event(event, context)
self.get_success(self.store.persist_event(event, context))
@defer.inlineCallbacks
def test_redact(self):
yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = yield self.inject_message(self.room1, self.u_alice, "t")
msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
# Check event has not been redacted:
event = yield self.store.get_event(msg_event.event_id)
event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@ -136,11 +135,11 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
yield self.inject_redaction(
self.room1, msg_event.event_id, self.u_alice, reason
self.get_success(
self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
event = yield self.store.get_event(msg_event.event_id)
event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertEqual(msg_event.event_id, event.event_id)
@ -164,15 +163,18 @@ class RedactionTestCase(unittest.TestCase):
event.unsigned["redacted_because"],
)
@defer.inlineCallbacks
def test_redact_join(self):
yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
event = yield self.store.get_event(msg_event.event_id)
msg_event = self.get_success(
self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
)
event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@ -187,13 +189,13 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
yield self.inject_redaction(
self.room1, msg_event.event_id, self.u_alice, reason
self.get_success(
self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
# Check redaction
event = yield self.store.get_event(msg_event.event_id)
event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertTrue("redacted_because" in event.unsigned)

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from synapse.types import Requester, RoomID, UserID
from tests import unittest
from tests.utils import create_room, setup_test_homeserver
@ -84,3 +84,38 @@ class RoomMemberStoreTestCase(unittest.TestCase):
)
],
)
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.room_creator = homeserver.get_room_creation_handler()
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
requester = Requester(user, None, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
self.get_success(
self.store._simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
"progress_json": "{}",
"depends_on": None,
},
)
)
# ... and tell the DataStore that it hasn't finished all updates yet
self.store._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)

View file

@ -61,7 +61,10 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths(
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
"GET",
[re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")],
_callback,
"test_servlet",
)
request, channel = make_request(
@ -82,7 +85,9 @@ class JsonResourceTests(unittest.TestCase):
raise Exception("boo")
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
@ -105,7 +110,9 @@ class JsonResourceTests(unittest.TestCase):
return make_deferred_yieldable(d)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
@ -122,7 +129,9 @@ class JsonResourceTests(unittest.TestCase):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
@ -143,7 +152,9 @@ class JsonResourceTests(unittest.TestCase):
self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
render(request, res, self.reactor)

View file

@ -23,8 +23,6 @@ from mock import Mock
from canonicaljson import json
import twisted
import twisted.logger
from twisted.internet.defer import Deferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@ -80,10 +78,6 @@ class TestCase(unittest.TestCase):
@around(self)
def setUp(orig):
# enable debugging of delayed calls - this means that we get a
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if LoggingContext.current_context() is not LoggingContext.sentinel:

View file

@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached
from tests import unittest
@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
obj = Cls()
# set off a deferred which will do a cache lookup
@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_iterable(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(iterable=True)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks

View file

@ -126,7 +126,6 @@ def default_config(name, parse=False):
"enable_registration": True,
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"expire_access_token": False,
"trusted_third_party_id_servers": [],
"room_invite_state_types": [],
"password_providers": [],
@ -471,7 +470,7 @@ class MockHttpResource(HttpServer):
raise KeyError("No event can handle %s" % path)
def register_paths(self, method, path_patterns, callback):
def register_paths(self, method, path_patterns, callback, servlet_name):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))