flake8 and make API more compact & docced

This commit is contained in:
Matthew Hodgson 2018-09-07 22:59:38 +01:00
commit 967fdfef10
78 changed files with 908 additions and 757 deletions

View file

@ -3,6 +3,5 @@ Dockerfile
.gitignore .gitignore
demo/etc demo/etc
tox.ini tox.ini
synctl
.git/* .git/*
.tox/* .tox/*

1
.gitignore vendored
View file

@ -44,6 +44,7 @@ media_store/
build/ build/
venv/ venv/
venv*/ venv*/
*venv/
localhost-800*/ localhost-800*/
static/client/register/register_config.js static/client/register/register_config.js

View file

@ -8,9 +8,6 @@ before_script:
- git remote set-branches --add origin develop - git remote set-branches --add origin develop
- git fetch origin develop - git fetch origin develop
services:
- postgresql
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
@ -25,6 +22,8 @@ matrix:
- python: 2.7 - python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36

View file

@ -1,3 +1,77 @@
Synapse 0.33.4 (2018-09-07)
===========================
Internal Changes
----------------
- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
Synapse 0.33.4rc2 (2018-09-06)
==============================
Pull in security fixes from v0.33.3.1
Synapse 0.33.3.1 (2018-09-06)
=============================
SECURITY FIXES
--------------
- Fix an issue where event signatures were not always correctly validated ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
- Fix an issue where server_acls could be circumvented for incoming events ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
Internal Changes
----------------
- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
Synapse 0.33.4rc1 (2018-09-04)
==============================
Features
--------
- Support profile API endpoints on workers ([\#3659](https://github.com/matrix-org/synapse/issues/3659))
- Server notices for resource limit blocking ([\#3680](https://github.com/matrix-org/synapse/issues/3680))
- Allow guests to use /rooms/:roomId/event/:eventId ([\#3724](https://github.com/matrix-org/synapse/issues/3724))
- Add mau_trial_days config param, so that users only get counted as MAU after N days. ([\#3749](https://github.com/matrix-org/synapse/issues/3749))
- Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)). ([\#3751](https://github.com/matrix-org/synapse/issues/3751))
Bugfixes
--------
- Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues ([\#3722](https://github.com/matrix-org/synapse/issues/3722))
- Fix bug where we resent "limit exceeded" server notices repeatedly ([\#3747](https://github.com/matrix-org/synapse/issues/3747))
- Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices ([\#3753](https://github.com/matrix-org/synapse/issues/3753))
- Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic ([\#3754](https://github.com/matrix-org/synapse/issues/3754))
- Fix tagging of server notice rooms ([\#3755](https://github.com/matrix-org/synapse/issues/3755), [\#3756](https://github.com/matrix-org/synapse/issues/3756))
- Fix 'admin_uri' config variable and error parameter to be 'admin_contact' to match the spec. ([\#3758](https://github.com/matrix-org/synapse/issues/3758))
- Don't return non-LL-member state in incremental sync state blocks ([\#3760](https://github.com/matrix-org/synapse/issues/3760))
- Fix bug in sending presence over federation ([\#3768](https://github.com/matrix-org/synapse/issues/3768))
- Fix bug where preserved threepid user comes to sign up and server is mau blocked ([\#3777](https://github.com/matrix-org/synapse/issues/3777))
Internal Changes
----------------
- Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme. ([\#3378](https://github.com/matrix-org/synapse/issues/3378))
- Refactor state module to support multiple room versions ([\#3673](https://github.com/matrix-org/synapse/issues/3673))
- The synapse.storage module has been ported to Python 3. ([\#3725](https://github.com/matrix-org/synapse/issues/3725))
- Split the state_group_cache into member and non-member state events (and so speed up LL /sync) ([\#3726](https://github.com/matrix-org/synapse/issues/3726))
- Log failure to authenticate remote servers as warnings (without stack traces) ([\#3727](https://github.com/matrix-org/synapse/issues/3727))
- The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content. ([\#3730](https://github.com/matrix-org/synapse/issues/3730))
- Reference the need for an HTTP replication port when using the federation_reader worker ([\#3734](https://github.com/matrix-org/synapse/issues/3734))
- Fix minor spelling error in federation client documentation. ([\#3735](https://github.com/matrix-org/synapse/issues/3735))
- Remove redundant state resolution function ([\#3737](https://github.com/matrix-org/synapse/issues/3737))
- The test suite now passes on PostgreSQL. ([\#3740](https://github.com/matrix-org/synapse/issues/3740))
- Fix MAU cache invalidation due to missing yield ([\#3746](https://github.com/matrix-org/synapse/issues/3746))
- Make sure that we close db connections opened during init ([\#3764](https://github.com/matrix-org/synapse/issues/3764))
Synapse 0.33.3 (2018-08-22) Synapse 0.33.3 (2018-08-22)
=========================== ===========================

View file

@ -742,6 +742,18 @@ so an example nginx configuration might look like::
} }
} }
and an example apache configuration may look like::
<VirtualHost *:443>
SSLEngine on
ServerName matrix.example.com;
<Location /_matrix>
ProxyPass http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse http://127.0.0.1:8008/_matrix
</Location>
</VirtualHost>
You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true`` You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
recorded correctly. recorded correctly.

View file

@ -1 +0,0 @@
Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme.

View file

@ -1 +0,0 @@
Support profile API endpoints on workers

View file

@ -1 +0,0 @@
Refactor state module to support multiple room versions

View file

@ -1 +0,0 @@
Server notices for resource limit blocking

View file

@ -1 +0,0 @@
Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues

View file

@ -1 +0,0 @@
Allow guests to use /rooms/:roomId/event/:eventId

View file

@ -1 +0,0 @@
The synapse.storage module has been ported to Python 3.

View file

@ -1 +0,0 @@
Split the state_group_cache into member and non-member state events (and so speed up LL /sync)

View file

@ -1 +0,0 @@
Log failure to authenticate remote servers as warnings (without stack traces)

View file

@ -1 +0,0 @@
The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content.

View file

@ -1 +0,0 @@
Reference the need for an HTTP replication port when using the federation_reader worker

View file

@ -1 +0,0 @@
Fix minor spelling error in federation client documentation.

View file

@ -1 +0,0 @@
Remove redundant state resolution function

View file

@ -1 +0,0 @@
The test suite now passes on PostgreSQL.

View file

@ -1 +0,0 @@
Fix MAU cache invalidation due to missing yield

View file

@ -1 +0,0 @@
Fix bug where we resent "limit exceeded" server notices repeatedly

View file

@ -1 +0,0 @@
Add mau_trial_days config param, so that users only get counted as MAU after N days.

View file

@ -1 +0,0 @@
Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)).

View file

@ -1 +0,0 @@
Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices

View file

@ -1 +0,0 @@
Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic

View file

@ -1 +0,0 @@
Fix tagging of server notice rooms

View file

@ -1 +0,0 @@
Fix tagging of server notice rooms

View file

@ -1 +0,0 @@
Fix 'admin_uri' config variable and error parameter to be 'admin_contact' to match the spec.

View file

@ -1 +0,0 @@
Don't return non-LL-member state in incremental sync state blocks

View file

@ -1 +0,0 @@
Make sure that we close db connections opened during init

View file

@ -1 +0,0 @@
Fix bug in sending presence over federation

1
changelog.d/3771.misc Normal file
View file

@ -0,0 +1 @@
http/ is now ported to Python 3.

View file

@ -1 +0,0 @@
Fix bug where preserved threepid user comes to sign up and server is mau blocked

1
changelog.d/3788.bugfix Normal file
View file

@ -0,0 +1 @@
Remove connection ID for replication prometheus metrics, as it creates a large number of new series.

1
changelog.d/3790.feature Normal file
View file

@ -0,0 +1 @@
Implement `event_format` filter param in `/sync`

1
changelog.d/3795.misc Normal file
View file

@ -0,0 +1 @@
Make /sync slightly faster by avoiding needless copies

1
changelog.d/3800.bugfix Normal file
View file

@ -0,0 +1 @@
guest users should not be part of mau total

1
changelog.d/3803.misc Normal file
View file

@ -0,0 +1 @@
handlers/ is now ported to Python 3.

1
changelog.d/3804.bugfix Normal file
View file

@ -0,0 +1 @@
Bump dependency on pyopenssl 16.x, to avoid incompatibility with recent Twisted.

1
changelog.d/3805.misc Normal file
View file

@ -0,0 +1 @@
Limit the number of PDUs/EDUs per federation transaction

1
changelog.d/3806.misc Normal file
View file

@ -0,0 +1 @@
Only start postgres instance for postgres tests on Travis CI

1
changelog.d/3808.misc Normal file
View file

@ -0,0 +1 @@
tests/ is now ported to Python 3.

1
changelog.d/3810.bugfix Normal file
View file

@ -0,0 +1 @@
Fix existing room tags not coming down sync when joining a room

View file

@ -17,13 +17,14 @@ ignore =
[pep8] [pep8]
max-line-length = 90 max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik # W503 requires that binary operators be at the end, not start, of lines. Erik
# doesn't like it. E203 is contrary to PEP8. # doesn't like it. E203 is contrary to PEP8. E731 is silly.
ignore = W503,E203 ignore = W503,E203,E731
[flake8] [flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses # note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting # pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90 max-line-length = 90
ignore=W503,E203,E731
[isort] [isort]
line_length = 89 line_length = 89

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.33.3" __version__ = "0.33.4"

View file

@ -251,6 +251,7 @@ class FilterCollection(object):
"include_leave", False "include_leave", False
) )
self.event_fields = filter_json.get("event_fields", []) self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
def __repr__(self): def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib
from six.moves import urllib
from prometheus_client import Counter from prometheus_client import Counter
@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_user(self, service, user_id): def query_user(self, service, user_id):
if service.url is None: if service.url is None:
defer.returnValue(False) defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None response = None
try: try:
response = yield self.get_json(uri, { response = yield self.get_json(uri, {
@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_alias(self, service, alias): def query_alias(self, service, alias):
if service.url is None: if service.url is None:
defer.returnValue(False) defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias)) uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None response = None
try: try:
response = yield self.get_json(uri, { response = yield self.get_json(uri, {
@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url, service.url,
APP_SERVICE_PREFIX, APP_SERVICE_PREFIX,
kind, kind,
urllib.quote(protocol) urllib.parse.quote(protocol)
) )
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % ( uri = "%s%s/thirdparty/protocol/%s" % (
service.url, service.url,
APP_SERVICE_PREFIX, APP_SERVICE_PREFIX,
urllib.quote(protocol) urllib.parse.quote(protocol)
) )
try: try:
info = yield self.get_json(uri, {}) info = yield self.get_json(uri, {})
@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient):
txn_id = str(txn_id) txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" % uri = service.url + ("/transactions/%s" %
urllib.quote(txn_id)) urllib.parse.quote(txn_id))
try: try:
yield self.put_json( yield self.put_json(
uri=uri, uri=uri,

View file

@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple
import six import six
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -133,34 +136,25 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed. * throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext. The deferreds run their callbacks in the sentinel logcontext.
""" """
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
ctx = logcontext.LoggingContext.current_context() ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted): def callback(_, pdu):
with logcontext.PreserveLoggingContext(ctx): with logcontext.PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s: %s", "Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()
) )
return redacted return prune_event(pdu)
if self.spam_checker.check_event_for_spam(pdu): if self.spam_checker.check_event_for_spam(pdu):
logger.warn( logger.warn(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()
) )
return redacted return prune_event(pdu)
return pdu return pdu
@ -173,16 +167,116 @@ class FederationBase(object):
) )
return failure return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks( deferred.addCallbacks(
callback, errback, callback, errback,
callbackArgs=[pdu, redacted], callbackArgs=[pdu],
errbackArgs=[pdu], errbackArgs=[pdu],
) )
return deferreds return deferreds
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
])):
pass
def _check_sigs_on_pdus(keyring, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
pdus (Collection[EventBase]): the events to be checked
Returns:
List[Deferred]: a Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
# (currently this is written assuming the v1 room structure; we'll probably want a
# separate function for checking v2 rooms)
# we want to check that the event is signed by:
#
# (a) the server which created the event_id
#
# (b) the sender's server.
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
# invite, but the event may come from a different HS, for reasons that I don't
# entirely grok (why do the senders have to match? and if they do, why doesn't the
# joining server ask the inviting server to do the switcheroo with
# exchange_third_party_invite?).
#
# That's pretty awful, since redacting such an invite will render it invalid
# (because it will then look like a regular invite without a valid signature),
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
# first make sure that the event is signed by the event_id's domain
deferreds = keyring.verify_json_objects_for_server([
(p.event_id_domain, p.redacted_pdu_json)
for p in pdus_to_check
])
for p, d in zip(pdus_to_check, deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks.
pdus_to_check_sender = [
p for p in pdus_to_check
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
for p in pdus_to_check_sender
])
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
"""Given a list of one or more deferreds, either return the single deferred, or
combine into a DeferredList.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
else:
assert len(deferreds) == 1
return deferreds[0]
def _is_invite_via_3pid(event):
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
def event_from_pdu_json(pdu_json, outlier=False): def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation """Construct a FrozenEvent from an event json received over federation

View file

@ -99,7 +99,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_incoming_transaction(self, transaction_data): def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as # keep this as early as possible to make the calculated origin ts as
# accurate as possible. # accurate as possible.
request_time = self._clock.time_msec() request_time = self._clock.time_msec()
@ -108,34 +108,33 @@ class FederationServer(FederationBase):
if not transaction.transaction_id: if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id") raise Exception("Transaction missing transaction_id")
if not transaction.origin:
raise Exception("Transaction missing origin")
logger.debug("[%s] Got transaction", transaction.transaction_id) logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction # use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel. # multiple times in parallel.
with (yield self._transaction_linearizer.queue( with (yield self._transaction_linearizer.queue(
(transaction.origin, transaction.transaction_id), (origin, transaction.transaction_id),
)): )):
result = yield self._handle_incoming_transaction( result = yield self._handle_incoming_transaction(
transaction, request_time, origin, transaction, request_time,
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_incoming_transaction(self, transaction, request_time): def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response """ Process an incoming transaction and return the HTTP response
Args: Args:
origin (unicode): the server making the request
transaction (Transaction): incoming transaction transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at request_time (int): timestamp that the HTTP request arrived at
Returns: Returns:
Deferred[(int, object)]: http response code and body Deferred[(int, object)]: http response code and body
""" """
response = yield self.transaction_actions.have_responded(transaction) response = yield self.transaction_actions.have_responded(origin, transaction)
if response: if response:
logger.debug( logger.debug(
@ -149,7 +148,7 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus)) received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(transaction.origin) origin_host, _ = parse_server_name(origin)
pdus_by_room = {} pdus_by_room = {}
@ -190,7 +189,7 @@ class FederationServer(FederationBase):
event_id = pdu.event_id event_id = pdu.event_id
try: try:
yield self._handle_received_pdu( yield self._handle_received_pdu(
transaction.origin, pdu origin, pdu
) )
pdu_results[event_id] = {} pdu_results[event_id] = {}
except FederationError as e: except FederationError as e:
@ -212,7 +211,7 @@ class FederationServer(FederationBase):
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus): for edu in (Edu(**x) for x in transaction.edus):
yield self.received_edu( yield self.received_edu(
transaction.origin, origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
) )
@ -224,6 +223,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response)) logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response( yield self.transaction_actions.set_response(
origin,
transaction, transaction,
200, response 200, response
) )
@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
) )
return self._send_edu( return self._send_edu(
edu_type=edu_type, edu_type=edu_type,
origin=origin, origin=origin,
content=content, content=content,
) )
def on_query(self, query_type, args): def on_query(self, query_type, args):
@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return handler(args) return handler(args)
return self._get_query_client( return self._get_query_client(
query_type=query_type, query_type=query_type,
args=args, args=args,
) )

View file

@ -36,7 +36,7 @@ class TransactionActions(object):
self.store = datastore self.store = datastore
@log_function @log_function
def have_responded(self, transaction): def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and """ Have we already responded to a transaction with the same id and
origin? origin?
@ -50,11 +50,11 @@ class TransactionActions(object):
"transaction_id") "transaction_id")
return self.store.get_received_txn_response( return self.store.get_received_txn_response(
transaction.transaction_id, transaction.origin transaction.transaction_id, origin
) )
@log_function @log_function
def set_response(self, transaction, code, response): def set_response(self, origin, transaction, code, response):
""" Persist how we responded to a transaction. """ Persist how we responded to a transaction.
Returns: Returns:
@ -66,7 +66,7 @@ class TransactionActions(object):
return self.store.set_received_txn_response( return self.store.set_received_txn_response(
transaction.transaction_id, transaction.transaction_id,
transaction.origin, origin,
code, code,
response, response,
) )

View file

@ -463,7 +463,19 @@ class TransactionQueue(object):
# pending_transactions flag. # pending_transactions flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
# We can only include at most 50 PDUs per transactions
pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
if leftover_pdus:
self.pending_pdus_by_dest[destination] = leftover_pdus
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
# We can only include at most 100 EDUs per transactions
pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
if leftover_edus:
self.pending_edus_by_dest[destination] = leftover_edus
pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_edus.extend( pending_edus.extend(

View file

@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
try: try:
code, response = yield self.handler.on_incoming_transaction( code, response = yield self.handler.on_incoming_transaction(
transaction_data origin, transaction_data,
) )
except Exception: except Exception:
logger.exception("on_incoming_transaction failed") logger.exception("on_incoming_transaction failed")

View file

@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):
Args: Args:
password (unicode): Password to hash. password (unicode): Password to hash.
stored_hash (unicode): Expected hash value. stored_hash (bytes): Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash():
# Normalise the Unicode in the password # Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password) pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw( return bcrypt.checkpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8') stored_hash
) )
if stored_hash: if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii')
return make_deferred_yieldable( return make_deferred_yieldable(
threads.deferToThreadPool( threads.deferToThreadPool(
self.hs.get_reactor(), self.hs.get_reactor(),

View file

@ -330,7 +330,8 @@ class E2eKeysHandler(object):
(algorithm, key_id, ex_json, key) (algorithm, key_id, ex_json, key)
) )
else: else:
new_keys.append((algorithm, key_id, encode_canonical_json(key))) new_keys.append((
algorithm, key_id, encode_canonical_json(key).decode('ascii')))
yield self.store.add_e2e_one_time_keys( yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys user_id, device_id, time_now, new_keys
@ -358,7 +359,7 @@ def _exception_to_failure(e):
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize. # give a string for e.message, which json then fails to serialize.
return { return {
"status": 503, "message": str(e.message), "status": 503, "message": str(e),
} }

View file

@ -594,7 +594,7 @@ class FederationHandler(BaseHandler):
required_auth = set( required_auth = set(
a_id a_id
for event in events + state_events.values() + auth_events.values() for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events for a_id, _ in event.auth_events
) )
auth_events.update({ auth_events.update({
@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
) )
continue continue
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(str(e))
continue continue
except FederationDeniedError as e: except FederationDeniedError as e:
logger.info(e) logger.info(e)
@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler):
) )
if state_groups: if state_groups:
_, state = state_groups.items().pop() _, state = list(state_groups.items()).pop()
results = state results = state
if event.is_state(): if event.is_state():

View file

@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return # Filter out rooms that we don't want to return
rooms_to_scan = [ rooms_to_scan = [
r for r in sorted_rooms r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 if r not in newly_unpublished and rooms_to_num_joined[r] > 0
] ]
total_room_count = len(rooms_to_scan) total_room_count = len(rooms_to_scan)

View file

@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
batch_token = None batch_token = None
if batch: if batch:
try: try:
b = decode_base64(batch) b = decode_base64(batch).decode('ascii')
batch_group, batch_group_key, batch_token = b.split("\n") batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None assert batch_group is not None
@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather # it returns more from the same group (if applicable) rather
# than reverting to searching all results again. # than reverting to searching all results again.
if batch_group and batch_group_key: if batch_group and batch_group_key:
global_next_batch = encode_base64("%s\n%s\n%s" % ( global_next_batch = encode_base64(("%s\n%s\n%s" % (
batch_group, batch_group_key, pagination_token batch_group, batch_group_key, pagination_token
)) )).encode('ascii'))
else: else:
global_next_batch = encode_base64("%s\n%s\n%s" % ( global_next_batch = encode_base64(("%s\n%s\n%s" % (
"all", "", pagination_token "all", "", pagination_token
)) )).encode('ascii'))
for room_id, group in room_groups.items(): for room_id, group in room_groups.items():
group["next_batch"] = encode_base64("%s\n%s\n%s" % ( group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
"room_id", room_id, pagination_token "room_id", room_id, pagination_token
)) )).encode('ascii'))
allowed_events.extend(room_events) allowed_events.extend(room_events)

View file

@ -553,22 +553,36 @@ class SyncHandler(object):
summary = {} summary = {}
# TODO: only send these when they change. # TODO: only send these when they change.
summary["m.joined_member_count"] = details.get(Membership.JOIN, {}).get('count', 0) summary["m.joined_member_count"] = (
summary["m.invited_member_count"] = details.get(Membership.INVITE, {}).get('count', 0) details.get(Membership.JOIN, ([], 0))[1]
)
summary["m.invited_member_count"] = (
details.get(Membership.INVITE, ([], 0))[1]
)
if name_id or canonical_alias_id: if name_id or canonical_alias_id:
defer.returnValue(summary) defer.returnValue(summary)
joined_user_ids = [r[0] for r in details.get(Membership.JOIN, {}).get('users', [])] joined_user_ids = [
invited_user_ids = [r[0] for r in details.get(Membership.INVITE, {}).get('users', [])] r[0] for r in details.get(Membership.JOIN, ([], 0))[0]
]
invited_user_ids = [
r[0] for r in details.get(Membership.INVITE, ([], 0))[0]
]
gone_user_ids = ( gone_user_ids = (
[r[0] for r in details.get(Membership.LEAVE, {}).get('users', [])] + [r[0] for r in details.get(Membership.LEAVE, ([], 0))[0]] +
[r[0] for r in details.get(Membership.BAN, {}).get('users', [])] [r[0] for r in details.get(Membership.BAN, ([], 0))[0]]
) )
# FIXME: only build up a member_ids list for our heroes
member_ids = {} member_ids = {}
for m in (Membership.JOIN, Membership.INVITE, Membership.LEAVE, Membership.BAN): for m in (
for r in details.get(m, {}).get('users', []): Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
Membership.BAN
):
for r in details.get(m, ([], 0))[0]:
member_ids[r[0]] = r[1] member_ids[r[0]] = r[1]
# FIXME: order by stream ordering rather than as returned by SQL # FIXME: order by stream ordering rather than as returned by SQL
@ -775,7 +789,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids) logger.debug("filtering state from %r...", state_ids)
state_ids = { state_ids = {
t: event_id t: event_id
for t, event_id in state_ids.iteritems() for t, event_id in iteritems(state_ids)
if cache.get(t[1]) != event_id if cache.get(t[1]) != event_id
} }
logger.debug("...to %r", state_ids) logger.debug("...to %r", state_ids)
@ -1576,6 +1590,19 @@ class SyncHandler(object):
newly_joined_room=newly_joined, newly_joined_room=newly_joined,
) )
# When we join the room (or the client requests full_state), we should
# send down any existing tags. Usually the user won't have tags in a
# newly joined room, unless either a) they've joined before or b) the
# tag was added by synapse e.g. for server notice rooms.
if full_state:
user_id = sync_result_builder.sync_config.user.to_string()
tags = yield self.store.get_tags_for_room(user_id, room_id)
# If there aren't any tags, don't send the empty tags list down
# sync
if not tags:
tags = None
account_data_events = [] account_data_events = []
if tags is not None: if tags is not None:
account_data_events.append({ account_data_events.append({
@ -1608,7 +1635,12 @@ class SyncHandler(object):
sync_config.filter_collection.lazy_load_members() and sync_config.filter_collection.lazy_load_members() and
( (
any(ev.type == EventTypes.Member for ev in batch.events) or any(ev.type == EventTypes.Member for ev in batch.events) or
#(batch.limited and any(ev.type == EventTypes.Member for ev in state)) or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited and
any(t == EventTypes.Member for (t, k) in state.keys())
) or
since_token is None since_token is None
) )
): ):
@ -1731,17 +1763,17 @@ def _calculate_state(
event_id_to_key = { event_id_to_key = {
e: key e: key
for key, e in itertools.chain( for key, e in itertools.chain(
timeline_contains.items(), iteritems(timeline_contains),
previous.items(), iteritems(previous),
timeline_start.items(), iteritems(timeline_start),
current.items(), iteritems(current),
) )
} }
c_ids = set(e for e in current.values()) c_ids = set(e for e in itervalues(current))
ts_ids = set(e for e in timeline_start.values()) ts_ids = set(e for e in itervalues(timeline_start))
p_ids = set(e for e in previous.values()) p_ids = set(e for e in itervalues(previous))
tc_ids = set(e for e in timeline_contains.values()) tc_ids = set(e for e in itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events # If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync, # for the senders in the timeline into the state block returned by /sync,
@ -1755,7 +1787,7 @@ def _calculate_state(
if lazy_load_members: if lazy_load_members:
p_ids.difference_update( p_ids.difference_update(
e for t, e in timeline_start.iteritems() e for t, e in iteritems(timeline_start)
if t[0] == EventTypes.Member if t[0] == EventTypes.Member
) )

View file

@ -13,24 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib
from six import StringIO from six import text_type
from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from prometheus_client import Counter from prometheus_client import Counter
from OpenSSL import SSL from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import ( from twisted.web.client import (
Agent, Agent,
BrowserLikeRedirectAgent, BrowserLikeRedirectAgent,
ContentDecoderAgent, ContentDecoderAgent,
FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder, GzipDecoder,
HTTPConnectionPool, HTTPConnectionPool,
PartialDownloadError, PartialDownloadError,
@ -83,18 +84,20 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix: if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
self.user_agent = self.user_agent.encode('ascii')
@defer.inlineCallbacks @defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs): def request(self, method, uri, data=b'', headers=None):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.labels(method).inc() outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this) # log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri)) logger.info("Sending request %s %s", method, redact_uri(uri.encode('ascii')))
try: try:
request_deferred = self.agent.request( request_deferred = treq.request(
method, uri, *args, **kwargs method, uri, agent=self.agent, data=data, headers=headers
) )
add_timeout_to_deferred( add_timeout_to_deferred(
request_deferred, 60, self.hs.get_reactor(), request_deferred, 60, self.hs.get_reactor(),
@ -105,14 +108,14 @@ class SimpleHttpClient(object):
incoming_responses_counter.labels(method, response.code).inc() incoming_responses_counter.labels(method, response.code).inc()
logger.info( logger.info(
"Received response to %s %s: %s", "Received response to %s %s: %s",
method, redact_uri(uri), response.code method, redact_uri(uri.encode('ascii')), response.code
) )
defer.returnValue(response) defer.returnValue(response)
except Exception as e: except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc() incoming_responses_counter.labels(method, "ERR").inc()
logger.info( logger.info(
"Error sending request to %s %s: %s %s", "Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message method, redact_uri(uri.encode('ascii')), type(e).__name__, e.args[0]
) )
raise raise
@ -137,7 +140,8 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True) query_bytes = urllib.parse.urlencode(
encode_urlencode_args(args), True).encode("utf8")
actual_headers = { actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
@ -148,15 +152,14 @@ class SimpleHttpClient(object):
response = yield self.request( response = yield self.request(
"POST", "POST",
uri.encode("ascii"), uri,
headers=Headers(actual_headers), headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes)) data=query_bytes
) )
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) body = yield make_deferred_yieldable(treq.json_content(response))
defer.returnValue(body)
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(response.code, response.phrase, body)
@ -191,9 +194,9 @@ class SimpleHttpClient(object):
response = yield self.request( response = yield self.request(
"POST", "POST",
uri.encode("ascii"), uri,
headers=Headers(actual_headers), headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str)) data=json_str
) )
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
@ -248,7 +251,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON ValueError: if the response was not JSON
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
json_str = encode_canonical_json(json_body) json_str = encode_canonical_json(json_body)
@ -262,9 +265,9 @@ class SimpleHttpClient(object):
response = yield self.request( response = yield self.request(
"PUT", "PUT",
uri.encode("ascii"), uri,
headers=Headers(actual_headers), headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str)) data=json_str
) )
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
@ -293,7 +296,7 @@ class SimpleHttpClient(object):
HttpResponseException on a non-2xx HTTP response. HttpResponseException on a non-2xx HTTP response.
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
actual_headers = { actual_headers = {
@ -304,7 +307,7 @@ class SimpleHttpClient(object):
response = yield self.request( response = yield self.request(
"GET", "GET",
uri.encode("ascii"), uri,
headers=Headers(actual_headers), headers=Headers(actual_headers),
) )
@ -339,7 +342,7 @@ class SimpleHttpClient(object):
response = yield self.request( response = yield self.request(
"GET", "GET",
url.encode("ascii"), url,
headers=Headers(actual_headers), headers=Headers(actual_headers),
) )
@ -434,12 +437,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}): def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(encode_urlencode_args(args), True) query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
url.encode("ascii"), url,
bodyProducer=FileBodyProducer(StringIO(query_bytes)), data=query_bytes,
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent], b"User-Agent": [self.user_agent],
@ -510,7 +513,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg): def encode_urlencode_arg(arg):
if isinstance(arg, unicode): if isinstance(arg, text_type):
return arg.encode('utf-8') return arg.encode('utf-8')
elif isinstance(arg, list): elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg] return [encode_urlencode_arg(i) for i in arg]
@ -542,26 +545,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname, port): def creatorForNetloc(self, hostname, port):
return self return self
class FileBodyProducer(TwistedFileBodyProducer):
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
We override the pauseProducing and resumeProducing methods in twisted's
FileBodyProducer so that they do not raise exceptions if the task has
already completed.
"""
def pauseProducing(self):
try:
super(FileBodyProducer, self).pauseProducing()
except task.TaskDone:
# task has already completed
pass
def resumeProducing(self):
try:
super(FileBodyProducer, self).resumeProducing()
except task.NotPaused:
# task was not paused (probably because it had already completed)
pass

View file

@ -17,19 +17,19 @@ import cgi
import logging import logging
import random import random
import sys import sys
import urllib
from six import string_types from six import PY3, string_types
from six.moves.urllib import parse as urlparse from six.moves import urllib
from canonicaljson import encode_canonical_json, json import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter from prometheus_client import Counter
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.internet import defer, protocol, reactor from twisted.internet import defer, protocol, reactor
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
import synapse.metrics import synapse.metrics
@ -58,13 +58,18 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon
MAX_LONG_RETRIES = 10 MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3 MAX_SHORT_RETRIES = 3
if PY3:
MAXINT = sys.maxsize
else:
MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_client_options_factory = hs.tls_client_options_factory self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri): def endpointForURI(self, uri):
destination = uri.netloc destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
@ -93,26 +98,32 @@ class MatrixFederationHttpClient(object):
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._store = hs.get_datastore() self._store = hs.get_datastore()
self.version_string = hs.version_string self.version_string = hs.version_string.encode('ascii')
self._next_id = 1 self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes): def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse( return urllib.parse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "") (b"matrix", destination, path_bytes, param_bytes, query_bytes, b"")
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _request(self, destination, method, path, def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"", json=None, json_callback=None,
query_bytes=b"", retry_on_dns_fail=True, param_bytes=b"",
query=None, retry_on_dns_fail=True,
timeout=None, long_retries=False, timeout=None, long_retries=False,
ignore_backoff=False, ignore_backoff=False,
backoff_on_404=False): backoff_on_404=False):
""" Creates and sends a request to the given server """
Creates and sends a request to the given server.
Args: Args:
destination (str): The remote server to send the HTTP request to. destination (str): The remote server to send the HTTP request to.
method (str): HTTP method method (str): HTTP method
path (str): The HTTP path path (str): The HTTP path
json (dict or None): JSON to send in the body.
json_callback (func or None): A callback to generate the JSON.
query (dict or None): Query arguments.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404 backoff_on_404 (bool): Back off if we get a 404
@ -146,22 +157,29 @@ class MatrixFederationHttpClient(object):
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
) )
destination = destination.encode("ascii") headers_dict = {}
path_bytes = path.encode("ascii") path_bytes = path.encode("ascii")
with limiter: if query:
headers_dict[b"User-Agent"] = [self.version_string] query_bytes = encode_query_args(query)
headers_dict[b"Host"] = [destination] else:
query_bytes = b""
url_bytes = self._create_url( headers_dict = {
destination, path_bytes, param_bytes, query_bytes "User-Agent": [self.version_string],
) "Host": [destination],
}
with limiter:
url = self._create_url(
destination.encode("ascii"), path_bytes, param_bytes, query_bytes
).decode('ascii')
txn_id = "%s-O-%s" % (method, self._next_id) txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1) self._next_id = (self._next_id + 1) % (MAXINT - 1)
outbound_logger.info( outbound_logger.info(
"{%s} [%s] Sending request: %s %s", "{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes txn_id, destination, method, url
) )
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
@ -171,23 +189,33 @@ class MatrixFederationHttpClient(object):
else: else:
retries_left = MAX_SHORT_RETRIES retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse( http_url = urllib.parse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "") (b"", b"", path_bytes, param_bytes, query_bytes, b"")
) ).decode('ascii')
log_result = None log_result = None
try: try:
while True: while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
try: try:
request_deferred = self.agent.request( if json_callback:
json = json_callback()
if json:
data = encode_canonical_json(json)
headers_dict["Content-Type"] = ["application/json"]
self.sign_request(
destination, method, http_url, headers_dict, json
)
else:
data = None
self.sign_request(destination, method, http_url, headers_dict)
request_deferred = treq.request(
method, method,
url_bytes, url,
Headers(headers_dict), headers=Headers(headers_dict),
producer data=data,
agent=self.agent,
) )
add_timeout_to_deferred( add_timeout_to_deferred(
request_deferred, request_deferred,
@ -218,7 +246,7 @@ class MatrixFederationHttpClient(object):
txn_id, txn_id,
destination, destination,
method, method,
url_bytes, url,
_flatten_response_never_received(e), _flatten_response_never_received(e),
) )
@ -252,7 +280,7 @@ class MatrixFederationHttpClient(object):
# :'( # :'(
# Update transactions table? # Update transactions table?
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
body = yield readBody(response) body = yield treq.content(response)
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase, body response.code, response.phrase, body
) )
@ -297,11 +325,11 @@ class MatrixFederationHttpClient(object):
auth_headers = [] auth_headers = []
for key, sig in request["signatures"][self.server_name].items(): for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes( auth_headers.append((
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig, self.server_name, key, sig,
) )).encode('ascii')
)) )
headers_dict[b"Authorization"] = auth_headers headers_dict[b"Authorization"] = auth_headers
@ -347,24 +375,14 @@ class MatrixFederationHttpClient(object):
""" """
if not json_data_callback: if not json_data_callback:
def json_data_callback(): json_data_callback = lambda: data
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._request( response = yield self._request(
destination, destination,
"PUT", "PUT",
path, path,
body_callback=body_callback, json_callback=json_data_callback,
headers_dict={"Content-Type": ["application/json"]}, query=args,
query_bytes=encode_query_args(args),
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
@ -376,8 +394,8 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
body = yield readBody(response) body = yield treq.json_content(response)
defer.returnValue(json.loads(body)) defer.returnValue(body)
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False, def post_json(self, destination, path, data={}, long_retries=False,
@ -410,20 +428,12 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist is not on our federation whitelist
""" """
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._request( response = yield self._request(
destination, destination,
"POST", "POST",
path, path,
query_bytes=encode_query_args(args), query=args,
body_callback=body_callback, json=data,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
@ -434,9 +444,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
body = yield readBody(response) body = yield treq.json_content(response)
defer.returnValue(json.loads(body)) defer.returnValue(body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True, def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
@ -471,16 +481,11 @@ class MatrixFederationHttpClient(object):
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request( response = yield self._request(
destination, destination,
"GET", "GET",
path, path,
query_bytes=encode_query_args(args), query=args,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
@ -491,9 +496,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
body = yield readBody(response) body = yield treq.json_content(response)
defer.returnValue(json.loads(body)) defer.returnValue(body)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False, def delete_json(self, destination, path, long_retries=False,
@ -523,13 +528,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist is not on our federation whitelist
""" """
response = yield self._request( response = yield self._request(
destination, destination,
"DELETE", "DELETE",
path, path,
query_bytes=encode_query_args(args), query=args,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
@ -540,9 +543,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
body = yield readBody(response) body = yield treq.json_content(response)
defer.returnValue(json.loads(body)) defer.returnValue(body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={}, def get_file(self, destination, path, output_stream, args={},
@ -569,26 +572,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist is not on our federation whitelist
""" """
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request( response = yield self._request(
destination, destination,
"GET", "GET",
path, path,
query_bytes=query_bytes, query=args,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
) )
@ -639,30 +627,6 @@ def _readBodyToFile(response, stream, max_size):
return d return d
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
def resumeProducing(self):
pass
def _flatten_response_never_received(e): def _flatten_response_never_received(e):
if hasattr(e, "reasons"): if hasattr(e, "reasons"):
reasons = ", ".join( reasons = ", ".join(
@ -693,7 +657,7 @@ def check_content_type_is_json(headers):
"No Content-Type header" "No Content-Type header"
) )
c_type = c_type[0] # only the first header c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type) val, options = cgi.parse_header(c_type)
if val != "application/json": if val != "application/json":
raise RuntimeError( raise RuntimeError(
@ -711,6 +675,6 @@ def encode_query_args(args):
vs = [vs] vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs] encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True) query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes return query_bytes.encode('utf8')

View file

@ -204,14 +204,14 @@ class SynapseRequest(Request):
self.start_time = time.time() self.start_time = time.time()
self.request_metrics = RequestMetrics() self.request_metrics = RequestMetrics()
self.request_metrics.start( self.request_metrics.start(
self.start_time, name=servlet_name, method=self.method, self.start_time, name=servlet_name, method=self.method.decode('ascii'),
) )
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - Received request: %s %s", "%s - %s - Received request: %s %s",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
self.method, self.method.decode('ascii'),
self.get_redacted_uri() self.get_redacted_uri()
) )

View file

@ -40,9 +40,10 @@ REQUIREMENTS = {
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=17.1.0": ["twisted>=17.1.0"], "Twisted>=17.1.0": ["twisted>=17.1.0"],
"treq>=15.1": ["treq>=15.1"],
# We use crypto.get_elliptic_curve which is only supported in >=0.15 # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=0.15": ["OpenSSL>=0.15"], "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],

View file

@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
pending_commands = LaterGauge( pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands", "synapse_replication_tcp_protocol_pending_commands",
"", "",
["name", "conn_id"], ["name"],
lambda: { lambda: {
(p.name, p.conn_id): len(p.pending_commands) for p in connected_connections (p.name,): len(p.pending_commands) for p in connected_connections
}, },
) )
@ -607,9 +607,9 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge( transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer", "synapse_replication_tcp_protocol_transport_send_buffer",
"", "",
["name", "conn_id"], ["name"],
lambda: { lambda: {
(p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections (p.name,): transport_buffer_size(p) for p in connected_connections
}, },
) )
@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge( tcp_transport_kernel_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_send_buffer", "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
"", "",
["name", "conn_id"], ["name"],
lambda: { lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) (p.name,): transport_kernel_read_buffer_size(p, False)
for p in connected_connections for p in connected_connections
}, },
) )
@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge(
tcp_transport_kernel_read_buffer = LaterGauge( tcp_transport_kernel_read_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_read_buffer", "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
"", "",
["name", "conn_id"], ["name"],
lambda: { lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) (p.name,): transport_kernel_read_buffer_size(p, True)
for p in connected_connections for p in connected_connections
}, },
) )
@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge(
tcp_inbound_commands = LaterGauge( tcp_inbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_inbound_commands", "synapse_replication_tcp_protocol_inbound_commands",
"", "",
["command", "name", "conn_id"], ["command", "name"],
lambda: { lambda: {
(k[0], p.name, p.conn_id): count (k[0], p.name,): count
for p in connected_connections for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter) for k, count in iteritems(p.inbound_commands_counter)
}, },
@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge(
tcp_outbound_commands = LaterGauge( tcp_outbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_outbound_commands", "synapse_replication_tcp_protocol_outbound_commands",
"", "",
["command", "name", "conn_id"], ["command", "name"],
lambda: { lambda: {
(k[0], p.name, p.conn_id): count (k[0], p.name,): count
for p in connected_connections for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter) for k, count in iteritems(p.outbound_commands_counter)
}, },

View file

@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import ( from synapse.events.utils import (
format_event_for_client_v2_without_room_id, format_event_for_client_v2_without_room_id,
format_event_raw,
serialize_event, serialize_event,
) )
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
@staticmethod @staticmethod
def encode_response(time_now, sync_result, access_token_id, filter): def encode_response(time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
event_formatter = format_event_raw
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
joined = SyncRestServlet.encode_joined( joined = SyncRestServlet.encode_joined(
sync_result.joined, time_now, access_token_id, filter.event_fields sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
) )
invited = SyncRestServlet.encode_invited( invited = SyncRestServlet.encode_invited(
sync_result.invited, time_now, access_token_id, sync_result.invited, time_now, access_token_id,
event_formatter,
) )
archived = SyncRestServlet.encode_archived( archived = SyncRestServlet.encode_archived(
sync_result.archived, time_now, access_token_id, sync_result.archived, time_now, access_token_id,
filter.event_fields, filter.event_fields,
event_formatter,
) )
return { return {
@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
} }
@staticmethod @staticmethod
def encode_joined(rooms, time_now, token_id, event_fields): def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: the joined rooms list, in our dict[str, dict[str, object]]: the joined rooms list, in our
response format response format
@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room( joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
) )
return joined return joined
@staticmethod @staticmethod
def encode_invited(rooms, time_now, token_id): def encode_invited(rooms, time_now, token_id, event_formatter):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet):
time_now(int): current time - used as a baseline for age time_now(int): current time - used as a baseline for age
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: the invited rooms list, in our dict[str, dict[str, object]]: the invited rooms list, in our
@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
for room in rooms: for room in rooms:
invite = serialize_event( invite = serialize_event(
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=event_formatter,
is_invite=True, is_invite=True,
) )
unsigned = dict(invite.get("unsigned", {})) unsigned = dict(invite.get("unsigned", {}))
@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
return invited return invited
@staticmethod @staticmethod
def encode_archived(rooms, time_now, token_id, event_fields): def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: The invited rooms list, in our dict[str, dict[str, object]]: The invited rooms list, in our
response format response format
@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room( joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, time_now, token_id, joined=True, only_fields=None): def encode_room(
room, time_now, token_id, joined,
only_fields, event_formatter,
):
""" """
Args: Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
joined (bool): True if the user is joined to this room - will mean joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include. only_fields(list<str>): Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, object]: the room, encoded in our response format dict[str, object]: the room, encoded in our response format
""" """
def serialize(event): def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
return serialize_event( return serialize_event(
event, time_now, token_id=token_id, event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=event_formatter,
only_event_fields=only_fields, only_event_fields=only_fields,
) )

View file

@ -199,10 +199,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args: Args:
user_id(str): the user_id to query user_id(str): the user_id to query
""" """
if self.hs.config.limit_usage_by_mau: if self.hs.config.limit_usage_by_mau:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
return
is_trial = yield self.is_trial_user(user_id) is_trial = yield self.is_trial_user(user_id)
if is_trial: if is_trial:
# we don't track trial users in the MAU table.
return return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)

View file

@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase):
def test_reserved_threepid(self): def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1 self.hs.config.max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {'medium': 'email', 'address': 'reserved@server.com'} threepid = {'medium': 'email', 'address': 'reserved@server.com'}
unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
self.hs.config.mau_limits_reserved_threepids = [threepid] self.hs.config.mau_limits_reserved_threepids = [threepid]

View file

@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = ( self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"] site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
) )
request, channel = self.make_request("PUT", "presence/a/status") request, channel = self.make_request("PUT", "presence/a/status")
@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = ( self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"] site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
) )
request, channel = self.make_request("PUT", "presence/a/status") request, channel = self.make_request("PUT", "presence/a/status")

View file

@ -33,7 +33,7 @@ from ..utils import (
) )
def _expect_edu(destination, edu_type, content, origin="test"): def _expect_edu_transaction(edu_type, content, origin="test"):
return { return {
"origin": origin, "origin": origin,
"origin_server_ts": 1000000, "origin_server_ts": 1000000,
@ -42,10 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"):
} }
def _make_edu_json(origin, edu_type, content): def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode( return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
'utf8'
)
class TypingNotificationsTestCase(unittest.TestCase): class TypingNotificationsTestCase(unittest.TestCase):
@ -190,8 +188,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call( call(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu( data=_expect_edu_transaction(
"farm",
"m.typing", "m.typing",
content={ content={
"room_id": self.room_id, "room_id": self.room_id,
@ -221,11 +218,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
yield self.mock_federation_resource.trigger( (code, response) = yield self.mock_federation_resource.trigger(
"PUT", "PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json( _make_edu_transaction_json(
"farm",
"m.typing", "m.typing",
content={ content={
"room_id": self.room_id, "room_id": self.room_id,
@ -233,7 +229,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": True, "typing": True,
}, },
), ),
federation_auth=True, federation_auth_origin=b'farm',
) )
self.on_new_event.assert_has_calls( self.on_new_event.assert_has_calls(
@ -264,8 +260,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call( call(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu( data=_expect_edu_transaction(
"farm",
"m.typing", "m.typing",
content={ content={
"room_id": self.room_id, "room_id": self.room_id,

View file

@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.client.v1.room
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.http.server import JsonResource from synapse.rest.client.v1 import room
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
from .utils import RestHelper
PATH_PREFIX = b"/_matrix/client/api/v1" PATH_PREFIX = b"/_matrix/client/api/v1"
class RoomBase(unittest.TestCase): class RoomBase(unittest.HomeserverTestCase):
rmcreator_id = None rmcreator_id = None
def setUp(self): servlets = [room.register_servlets, room.register_deprecated_servlets]
self.clock = ThreadedMemoryReactorClock() def make_homeserver(self, reactor, clock):
self.hs_clock = Clock(self.clock)
self.hs = setup_test_homeserver( self.hs = self.setup_test_homeserver(
self.addCleanup,
"red", "red",
http_client=None, http_client=None,
clock=self.hs_clock,
reactor=self.clock,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase):
self.hs.get_federation_handler = Mock(return_value=Mock()) self.hs.get_federation_handler = Mock(return_value=Mock())
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234")
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
self.hs.get_datastore().insert_client_ip = _insert_client_ip self.hs.get_datastore().insert_client_ip = _insert_client_ip
self.resource = JsonResource(self.hs) return self.hs
synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
self.helper = RestHelper(self.hs, self.resource, self.user_id)
class RoomPermissionsTestCase(RoomBase): class RoomPermissionsTestCase(RoomBase):
""" Tests room permissions. """ """ Tests room permissions. """
user_id = b"@sid1:red" user_id = "@sid1:red"
rmcreator_id = b"@notme:red" rmcreator_id = "@notme:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomPermissionsTestCase, self).setUp()
self.helper.auth_user_id = self.rmcreator_id self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id # create some rooms under the name rmcreator_id
@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = ( self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid) "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode('ascii') ).encode('ascii')
request, channel = make_request( request, channel = self.make_request(
b"PUT", "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
self.created_rmid_msg_path,
b'{"msgtype":"m.text","body":"test msg"}',
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(200, channel.code, channel.result)
# set topic for public room # set topic for public room
request, channel = make_request( request, channel = self.make_request(
b"PUT", "PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
b'{"topic":"Public Room Topic"}', b'{"topic":"Public Room Topic"}',
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(200, channel.code, channel.result)
# auth as user_id now # auth as user_id now
self.helper.auth_user_id = self.user_id self.helper.auth_user_id = self.user_id
@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase):
seq = iter(range(100)) seq = iter(range(100))
def send_msg_path(): def send_msg_path():
return b"/rooms/%s/send/m.room.message/mid%s" % ( return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid, self.created_rmid,
str(next(seq)).encode('ascii'), str(next(seq)),
) )
# send message in uncreated room, expect 403 # send message in uncreated room, expect 403
request, channel = make_request( request, channel = self.make_request(
b"PUT", "PUT",
b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content, msg_content,
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403 # send message in created room not joined (no state), expect 403
request, channel = make_request(b"PUT", send_msg_path(), msg_content) request, channel = self.make_request("PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403 # send message in created room and invited, expect 403
self.helper.invite( self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
) )
request, channel = make_request(b"PUT", send_msg_path(), msg_content) request, channel = self.make_request("PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200 # send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id) self.helper.join(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", send_msg_path(), msg_content) request, channel = self.make_request("PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403 # send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id) self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", send_msg_path(), msg_content) request, channel = self.make_request("PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self): def test_topic_perms(self):
topic_content = b'{"topic":"My Topic Name"}' topic_content = b'{"topic":"My Topic Name"}'
topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403 # set/get topic in uncreated room, expect 403
request, channel = make_request( request, channel = self.make_request(
b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403 # set/get topic in created PRIVATE room not joined, expect 403
request, channel = make_request(b"PUT", topic_path, topic_content) request, channel = self.make_request("PUT", topic_path, topic_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", topic_path) request, channel = self.make_request("GET", topic_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403 # set topic in created PRIVATE room and invited, expect 403
self.helper.invite( self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
) )
request, channel = make_request(b"PUT", topic_path, topic_content) request, channel = self.make_request("PUT", topic_path, topic_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403 # get topic in created PRIVATE room and invited, expect 403
request, channel = make_request(b"GET", topic_path) request, channel = self.make_request("GET", topic_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200 # set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id) self.helper.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default # Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id self.helper.auth_user_id = self.rmcreator_id
request, channel = make_request(b"PUT", topic_path, topic_content) request, channel = self.make_request("PUT", topic_path, topic_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id self.helper.auth_user_id = self.user_id
request, channel = make_request(b"GET", topic_path) request, channel = self.make_request("GET", topic_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content), channel.json_body) self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403 # set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id) self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", topic_path, topic_content) request, channel = self.make_request("PUT", topic_path, topic_content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", topic_path) request, channel = self.make_request("GET", topic_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403 # get topic in PUBLIC room, not joined, expect 403
request, channel = make_request( request, channel = self.make_request(
b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403 # set topic in PUBLIC room, not joined, expect 403
request, channel = make_request( request, channel = self.make_request(
b"PUT", "PUT",
b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content, topic_content,
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None): def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members: for member in members:
path = b"/rooms/%s/state/m.room.member/%s" % (room, member) path = "/rooms/%s/state/m.room.member/%s" % (room, member)
request, channel = make_request(b"GET", path) request, channel = self.make_request("GET", path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(expect_code, int(channel.result["code"])) self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self): def test_membership_basic_room_perms(self):
# === room does not exist === # === room does not exist ===
@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase):
""" Tests /rooms/$room_id/members/list REST events.""" """ Tests /rooms/$room_id/members/list REST events."""
user_id = b"@sid1:red" user_id = "@sid1:red"
def test_get_member_list(self): def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self): def test_get_member_list_no_room(self):
request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self): def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as(b"@some_other_guy:red") room_id = self.helper.create_room_as("@some_other_guy:red")
request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self): def test_get_member_list_mixed_memberships(self):
room_creator = b"@some_other_guy:red" room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator) room_id = self.helper.create_room_as(room_creator)
room_path = b"/rooms/%s/members" % room_id room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited. # can't see list if you're just invited.
request, channel = make_request(b"GET", room_path) request, channel = self.make_request("GET", room_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id) self.helper.join(room=room_id, user=self.user_id)
# can see list now joined # can see list now joined
request, channel = make_request(b"GET", room_path) request, channel = self.make_request("GET", room_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id) self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left # can see old list once left
request, channel = make_request(b"GET", room_path) request, channel = self.make_request("GET", room_path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomsCreateTestCase(RoomBase): class RoomsCreateTestCase(RoomBase):
""" Tests /rooms and /rooms/$room_id REST events. """ """ Tests /rooms and /rooms/$room_id REST events. """
user_id = b"@sid1:red" user_id = "@sid1:red"
def test_post_room_no_keys(self): def test_post_room_no_keys(self):
# POST with no config keys, expect new room id # POST with no config keys, expect new room id
request, channel = make_request(b"POST", b"/createRoom", b"{}") request, channel = self.make_request("POST", "/createRoom", "{}")
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), channel.result) self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self): def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id # POST with visibility config key, expect new room id
request, channel = make_request( request, channel = self.make_request(
b"POST", b"/createRoom", b'{"visibility":"private"}' "POST", "/createRoom", b'{"visibility":"private"}'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self): def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id # POST with custom config keys, expect new room id
request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') request, channel = self.make_request(
render(request, self.resource, self.clock) "POST", "/createRoom", b'{"custom":"stuff"}'
self.assertEquals(200, int(channel.result["code"])) )
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self): def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id # POST with custom + known config keys, expect new room id
request, channel = make_request( request, channel = self.make_request(
b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self): def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400 # POST with invalid content / paths, expect 400
request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"])) self.assertEquals(400, channel.code)
request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"])) self.assertEquals(400, channel.code)
class RoomTopicTestCase(RoomBase): class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """ """ Tests /rooms/$room_id/topic REST events. """
user_id = b"@sid1:red" user_id = "@sid1:red"
def setUp(self):
super(RoomTopicTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
# create the room # create the room
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def test_invalid_puts(self): def test_invalid_puts(self):
# missing keys or invalid json # missing keys or invalid json
request, channel = make_request(b"PUT", self.path, '{}') request, channel = self.make_request("PUT", self.path, '{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '{"nao') request, channel = self.make_request("PUT", self.path, '{"nao')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, 'text only') request, channel = self.make_request("PUT", self.path, 'text only')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '') request, channel = self.make_request("PUT", self.path, '')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type # valid key, wrong type
content = '{"topic":["Topic name"]}' content = '{"topic":["Topic name"]}'
request, channel = make_request(b"PUT", self.path, content) request, channel = self.make_request("PUT", self.path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self): def test_rooms_topic(self):
# nothing should be there # nothing should be there
request, channel = make_request(b"GET", self.path) request, channel = self.make_request("GET", self.path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put # valid put
content = '{"topic":"Topic name"}' content = '{"topic":"Topic name"}'
request, channel = make_request(b"PUT", self.path, content) request, channel = self.make_request("PUT", self.path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get # valid get
request, channel = make_request(b"GET", self.path) request, channel = self.make_request("GET", self.path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body) self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self): def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys # valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}' content = '{"topic":"Seasons","subtopic":"Summer"}'
request, channel = make_request(b"PUT", self.path, content) request, channel = self.make_request("PUT", self.path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get # valid get
request, channel = make_request(b"GET", self.path) request, channel = self.make_request("GET", self.path)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body) self.assert_dict(json.loads(content), channel.json_body)
class RoomMemberStateTestCase(RoomBase): class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """ """ Tests /rooms/$room_id/members/$user_id/state REST events. """
user_id = b"@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMemberStateTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def tearDown(self):
pass
def test_invalid_puts(self): def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json # missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}') request, channel = self.make_request("PUT", path, '{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}') request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao') request, channel = self.make_request("PUT", path, '{"nao')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only') request, channel = self.make_request("PUT", path, 'text only')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '') request, channel = self.make_request("PUT", path, '')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types # valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % ( content = '{"membership":["%s","%s","%s"]}' % (
@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN, Membership.JOIN,
Membership.LEAVE, Membership.LEAVE,
) )
request, channel = make_request(b"PUT", path, content.encode('ascii')) request, channel = self.make_request("PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self): def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % ( path = "/rooms/%s/state/m.room.member/%s" % (
@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room) # valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN content = '{"membership":"%s"}' % Membership.JOIN
request, channel = make_request(b"PUT", path, content.encode('ascii')) request, channel = self.make_request("PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN} expected_response = {"membership": Membership.JOIN}
self.assertEquals(expected_response, channel.json_body) self.assertEquals(expected_response, channel.json_body)
@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message # valid invite message
content = '{"membership":"%s"}' % Membership.INVITE content = '{"membership":"%s"}' % Membership.INVITE
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body) self.assertEquals(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self): def test_rooms_members_other_custom_keys(self):
@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE, Membership.INVITE,
"Join us!", "Join us!",
) )
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body) self.assertEquals(json.loads(content), channel.json_body)
@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMessagesTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self): def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json # missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}') request, channel = self.make_request("PUT", path, b'{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}') request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao') request, channel = self.make_request("PUT", path, b'{"nao')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only') request, channel = self.make_request("PUT", path, b'text only')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '') request, channel = self.make_request("PUT", path, b'')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self): def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = '{"body":"test","msgtype":{"type":"a"}}' content = b'{"body":"test","msgtype":{"type":"a"}}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types # custom message types
content = '{"body":"test","msgtype":"test.custom.text"}' content = b'{"body":"test","msgtype":"test.custom.text"}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type # m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = '{"body":"test2","msgtype":"m.text"}' content = b'{"body":"test2","msgtype":"m.text"}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomInitialSyncTestCase(RoomBase): class RoomInitialSyncTestCase(RoomBase):
@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomInitialSyncTestCase, self).setUp()
# create the room # create the room
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self): def test_initial_sync(self):
request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) request, channel = self.make_request(
render(request, self.resource, self.clock) "GET", "/rooms/%s/initialSync" % self.room_id
self.assertEquals(200, int(channel.result["code"])) )
self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"]) self.assertEquals(self.room_id, channel.json_body["room_id"])
self.assertEquals("join", channel.json_body["membership"]) self.assertEquals("join", channel.json_body["membership"])
@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMessageListTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0_0"
request, channel = make_request( request, channel = self.make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body) self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start']) self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
@ -837,11 +790,11 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0_0"
request, channel = make_request( request, channel = self.make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body) self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start']) self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)

View file

@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertTrue( self.assertTrue(
set( set(
[ ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
"next_batch",
"rooms",
"account_data",
"to_device",
"device_lists",
]
).issubset(set(channel.json_body.keys())) ).issubset(set(channel.json_body.keys()))
) )

View file

@ -65,7 +65,7 @@ class FakeChannel(object):
def getPeer(self): def getPeer(self):
# We give an address so that getClientIP returns a non null entry, # We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU # causing us to record the MAU
return address.IPv4Address(b"TCP", "127.0.0.1", 3423) return address.IPv4Address("TCP", "127.0.0.1", 3423)
def getHost(self): def getHost(self):
return None return None

View file

@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock()
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
content={"msgtype": ServerNoticeMsgType}, )
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
content={"msgtype": ServerNoticeMsgType}, )
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(return_value=1000)
return_value=1000,
)
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(return_value=1000)
return_value=1000,
)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
# Now lets get the last load of messages in the service notice room and # Now lets get the last load of messages in the service notice room and
# check that there is only one server notice # check that there is only one server notice
room_id = yield self.server_notices_manager.get_notice_room_for_user( room_id = yield self.server_notices_manager.get_notice_room_for_user(
self.user_id, self.user_id
) )
token = yield self.event_source.get_current_token() token = yield self.event_source.get_current_token()
events, _ = yield self.store.get_recent_events_for_room( events, _ = yield self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key, room_id, limit=100, end_token=token.room_key
) )
count = 0 count = 0

View file

@ -101,13 +101,11 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
yield self.store.register(user_id=user_id, token="123", password_hash=None)
active = yield self.store.user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"
) )

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
from twisted.internet import defer
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -23,7 +26,8 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.store = hs.get_datastore() self.store = hs.get_datastore()
hs.config.limit_usage_by_mau = True
hs.config.max_mau_value = 50
# Advance the clock a bit # Advance the clock a bit
reactor.advance(FORTY_DAYS) reactor.advance(FORTY_DAYS)
@ -73,7 +77,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
active_count = self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num) self.assertEquals(self.get_success(active_count), user_num)
# Test that regalar users are removed from the db # Test that regular users are removed from the db
ru_count = 2 ru_count = 2
self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server") self.store.upsert_monthly_active_user("@ru2:server")
@ -139,3 +143,43 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0) self.assertEquals(self.get_success(count), 0)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "user_id"
self.store.register(
user_id=user_id, token="123", password_hash=None, make_guest=True
)
self.store.upsert_monthly_active_user = Mock()
self.store.populate_monthly_active_users(user_id)
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock()
self.store.is_trial_user = Mock(
return_value=defer.succeed(False)
)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
self.store.populate_monthly_active_users('user_id')
self.pump()
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock()
self.store.is_trial_user = Mock(
return_value=defer.succeed(False)
)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(
self.hs.get_clock().time_msec()
)
)
self.store.populate_monthly_active_users('user_id')
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()

View file

@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
group, [], filtered_types=[EventTypes.Member]
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member] group,
[],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({}, state_dict)
{},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with wildcard types # test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
key=group, key=group,
value=state_dict_ids, value=state_dict_ids,
# list fetched keys so it knows it's partial # list fetched keys so it knows it's partial
fetched_keys=( fetched_keys=((e1.type, e1.state_key),),
(e1.type, e1.state_key),
),
) )
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertEqual( self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
known_absent, self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
set(
[
(e1.type, e1.state_key),
]
),
)
self.assertDictEqual(
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id,
},
)
############################################ ############################################
# test that things work with a partial cache # test that things work with a partial cache
@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
group, [], filtered_types=[EventTypes.Member]
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member] group,
[],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members wildcard types # test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual( self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual( self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
@ -404,18 +384,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)

View file

@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def create_user(self, localpart): def create_user(self, localpart):
request_data = json.dumps({ request_data = json.dumps(
"username": localpart, {
"password": "monkey", "username": localpart,
"auth": {"type": LoginType.DUMMY}, "password": "monkey",
}) "auth": {"type": LoginType.DUMMY},
}
)
request, channel = make_request(b"POST", b"/register", request_data) request, channel = make_request("POST", "/register", request_data)
render(request, self.resource, self.reactor) render(request, self.resource, self.reactor)
if channel.result["code"] != b"200": if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
int(channel.result["code"]), channel.code, channel.result["reason"], channel.result["body"]
channel.result["reason"],
channel.result["body"],
).to_synapse_error() ).to_synapse_error()
access_token = channel.json_body["access_token"] access_token = channel.json_body["access_token"]
@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase):
return access_token return access_token
def do_sync_for_user(self, token): def do_sync_for_user(self, token):
request, channel = make_request(b"GET", b"/sync", access_token=token) request, channel = make_request(
"GET", "/sync", access_token=token.encode('ascii')
)
render(request, self.resource, self.reactor) render(request, self.resource, self.reactor)
if channel.result["code"] != b"200": if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
int(channel.result["code"]), channel.code, channel.result["reason"], channel.result["body"]
channel.result["reason"],
channel.result["body"],
).to_synapse_error() ).to_synapse_error()

View file

@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase):
graph = Graph( graph = Graph(
nodes={ nodes={
"START": DictObj( "START": DictObj(
type=EventTypes.Create, state_key="", content={}, depth=1, type=EventTypes.Create, state_key="", content={}, depth=1
), ),
"A": DictObj(type=EventTypes.Message, depth=2), "A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3), "B": DictObj(type=EventTypes.Message, depth=3),

View file

@ -100,8 +100,13 @@ class TestHomeServer(HomeServer):
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver( def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None, cleanup_func,
homeserverToUse=TestHomeServer, **kargs name="test",
datastore=None,
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
**kargs
): ):
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
@ -147,6 +152,7 @@ def setup_test_homeserver(
config.hs_disabled_message = "" config.hs_disabled_message = ""
config.hs_disabled_limit_type = "" config.hs_disabled_limit_type = ""
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_trial_days = 0
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_contact = None config.admin_contact = None
config.rc_messages_per_second = 10000 config.rc_messages_per_second = 10000
@ -321,7 +327,9 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request, federation_auth=False): def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
""" Fire an HTTP event. """ Fire an HTTP event.
Args: Args:
@ -330,6 +338,7 @@ class MockHttpResource(HttpServer):
content : The HTTP body content : The HTTP body
mock_request : Mocked request to pass to the event so it can get mock_request : Mocked request to pass to the event so it can get
content. content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns: Returns:
A tuple of (code, response) A tuple of (code, response)
Raises: Raises:
@ -350,8 +359,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-" mock_request.getClientIP.return_value = "-"
headers = {} headers = {}
if federation_auth: if federation_auth_origin is not None:
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
@ -570,16 +581,16 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new({ builder = event_builder_factory.new(
"type": EventTypes.Create, {
"state_key": "", "type": EventTypes.Create,
"sender": creator_id, "state_key": "",
"room_id": room_id, "sender": creator_id,
"content": {}, "room_id": room_id,
}) "content": {},
}
event, context = yield event_creation_handler.create_new_client_event(
builder
) )
event, context = yield event_creation_handler.create_new_client_event(builder)
yield store.persist_event(event, context) yield store.persist_event(event, context)