1
0
Fork 0
mirror of https://github.com/element-hq/synapse.git synced 2024-12-22 20:50:23 +03:00

flake8 and make API more compact & docced

This commit is contained in:
Matthew Hodgson 2018-09-07 22:59:38 +01:00
commit 967fdfef10
78 changed files with 908 additions and 757 deletions

View file

@ -3,6 +3,5 @@ Dockerfile
.gitignore
demo/etc
tox.ini
synctl
.git/*
.tox/*

1
.gitignore vendored
View file

@ -44,6 +44,7 @@ media_store/
build/
venv/
venv*/
*venv/
localhost-800*/
static/client/register/register_config.js

View file

@ -8,9 +8,6 @@ before_script:
- git remote set-branches --add origin develop
- git fetch origin develop
services:
- postgresql
matrix:
fast_finish: true
include:
@ -25,6 +22,8 @@ matrix:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- python: 3.6
env: TOX_ENV=py36

View file

@ -1,3 +1,77 @@
Synapse 0.33.4 (2018-09-07)
===========================
Internal Changes
----------------
- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
Synapse 0.33.4rc2 (2018-09-06)
==============================
Pull in security fixes from v0.33.3.1
Synapse 0.33.3.1 (2018-09-06)
=============================
SECURITY FIXES
--------------
- Fix an issue where event signatures were not always correctly validated ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
- Fix an issue where server_acls could be circumvented for incoming events ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
Internal Changes
----------------
- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
Synapse 0.33.4rc1 (2018-09-04)
==============================
Features
--------
- Support profile API endpoints on workers ([\#3659](https://github.com/matrix-org/synapse/issues/3659))
- Server notices for resource limit blocking ([\#3680](https://github.com/matrix-org/synapse/issues/3680))
- Allow guests to use /rooms/:roomId/event/:eventId ([\#3724](https://github.com/matrix-org/synapse/issues/3724))
- Add mau_trial_days config param, so that users only get counted as MAU after N days. ([\#3749](https://github.com/matrix-org/synapse/issues/3749))
- Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)). ([\#3751](https://github.com/matrix-org/synapse/issues/3751))
Bugfixes
--------
- Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues ([\#3722](https://github.com/matrix-org/synapse/issues/3722))
- Fix bug where we resent "limit exceeded" server notices repeatedly ([\#3747](https://github.com/matrix-org/synapse/issues/3747))
- Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices ([\#3753](https://github.com/matrix-org/synapse/issues/3753))
- Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic ([\#3754](https://github.com/matrix-org/synapse/issues/3754))
- Fix tagging of server notice rooms ([\#3755](https://github.com/matrix-org/synapse/issues/3755), [\#3756](https://github.com/matrix-org/synapse/issues/3756))
- Fix 'admin_uri' config variable and error parameter to be 'admin_contact' to match the spec. ([\#3758](https://github.com/matrix-org/synapse/issues/3758))
- Don't return non-LL-member state in incremental sync state blocks ([\#3760](https://github.com/matrix-org/synapse/issues/3760))
- Fix bug in sending presence over federation ([\#3768](https://github.com/matrix-org/synapse/issues/3768))
- Fix bug where preserved threepid user comes to sign up and server is mau blocked ([\#3777](https://github.com/matrix-org/synapse/issues/3777))
Internal Changes
----------------
- Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme. ([\#3378](https://github.com/matrix-org/synapse/issues/3378))
- Refactor state module to support multiple room versions ([\#3673](https://github.com/matrix-org/synapse/issues/3673))
- The synapse.storage module has been ported to Python 3. ([\#3725](https://github.com/matrix-org/synapse/issues/3725))
- Split the state_group_cache into member and non-member state events (and so speed up LL /sync) ([\#3726](https://github.com/matrix-org/synapse/issues/3726))
- Log failure to authenticate remote servers as warnings (without stack traces) ([\#3727](https://github.com/matrix-org/synapse/issues/3727))
- The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content. ([\#3730](https://github.com/matrix-org/synapse/issues/3730))
- Reference the need for an HTTP replication port when using the federation_reader worker ([\#3734](https://github.com/matrix-org/synapse/issues/3734))
- Fix minor spelling error in federation client documentation. ([\#3735](https://github.com/matrix-org/synapse/issues/3735))
- Remove redundant state resolution function ([\#3737](https://github.com/matrix-org/synapse/issues/3737))
- The test suite now passes on PostgreSQL. ([\#3740](https://github.com/matrix-org/synapse/issues/3740))
- Fix MAU cache invalidation due to missing yield ([\#3746](https://github.com/matrix-org/synapse/issues/3746))
- Make sure that we close db connections opened during init ([\#3764](https://github.com/matrix-org/synapse/issues/3764))
Synapse 0.33.3 (2018-08-22)
===========================

View file

@ -742,6 +742,18 @@ so an example nginx configuration might look like::
}
}
and an example apache configuration may look like::
<VirtualHost *:443>
SSLEngine on
ServerName matrix.example.com;
<Location /_matrix>
ProxyPass http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse http://127.0.0.1:8008/_matrix
</Location>
</VirtualHost>
You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
recorded correctly.

View file

@ -1 +0,0 @@
Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme.

View file

@ -1 +0,0 @@
Support profile API endpoints on workers

View file

@ -1 +0,0 @@
Refactor state module to support multiple room versions

View file

@ -1 +0,0 @@
Server notices for resource limit blocking

View file

@ -1 +0,0 @@
Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues

View file

@ -1 +0,0 @@
Allow guests to use /rooms/:roomId/event/:eventId

View file

@ -1 +0,0 @@
The synapse.storage module has been ported to Python 3.

View file

@ -1 +0,0 @@
Split the state_group_cache into member and non-member state events (and so speed up LL /sync)

View file

@ -1 +0,0 @@
Log failure to authenticate remote servers as warnings (without stack traces)

View file

@ -1 +0,0 @@
The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content.

View file

@ -1 +0,0 @@
Reference the need for an HTTP replication port when using the federation_reader worker

View file

@ -1 +0,0 @@
Fix minor spelling error in federation client documentation.

View file

@ -1 +0,0 @@
Remove redundant state resolution function

View file

@ -1 +0,0 @@
The test suite now passes on PostgreSQL.

View file

@ -1 +0,0 @@
Fix MAU cache invalidation due to missing yield

View file

@ -1 +0,0 @@
Fix bug where we resent "limit exceeded" server notices repeatedly

View file

@ -1 +0,0 @@
Add mau_trial_days config param, so that users only get counted as MAU after N days.

View file

@ -1 +0,0 @@
Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)).

View file

@ -1 +0,0 @@
Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices

View file

@ -1 +0,0 @@
Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic

View file

@ -1 +0,0 @@
Fix tagging of server notice rooms

View file

@ -1 +0,0 @@
Fix tagging of server notice rooms

View file

@ -1 +0,0 @@
Fix 'admin_uri' config variable and error parameter to be 'admin_contact' to match the spec.

View file

@ -1 +0,0 @@
Don't return non-LL-member state in incremental sync state blocks

View file

@ -1 +0,0 @@
Make sure that we close db connections opened during init

View file

@ -1 +0,0 @@
Fix bug in sending presence over federation

1
changelog.d/3771.misc Normal file
View file

@ -0,0 +1 @@
http/ is now ported to Python 3.

View file

@ -1 +0,0 @@
Fix bug where preserved threepid user comes to sign up and server is mau blocked

1
changelog.d/3788.bugfix Normal file
View file

@ -0,0 +1 @@
Remove connection ID for replication prometheus metrics, as it creates a large number of new series.

1
changelog.d/3790.feature Normal file
View file

@ -0,0 +1 @@
Implement `event_format` filter param in `/sync`

1
changelog.d/3795.misc Normal file
View file

@ -0,0 +1 @@
Make /sync slightly faster by avoiding needless copies

1
changelog.d/3800.bugfix Normal file
View file

@ -0,0 +1 @@
guest users should not be part of mau total

1
changelog.d/3803.misc Normal file
View file

@ -0,0 +1 @@
handlers/ is now ported to Python 3.

1
changelog.d/3804.bugfix Normal file
View file

@ -0,0 +1 @@
Bump dependency on pyopenssl 16.x, to avoid incompatibility with recent Twisted.

1
changelog.d/3805.misc Normal file
View file

@ -0,0 +1 @@
Limit the number of PDUs/EDUs per federation transaction

1
changelog.d/3806.misc Normal file
View file

@ -0,0 +1 @@
Only start postgres instance for postgres tests on Travis CI

1
changelog.d/3808.misc Normal file
View file

@ -0,0 +1 @@
tests/ is now ported to Python 3.

1
changelog.d/3810.bugfix Normal file
View file

@ -0,0 +1 @@
Fix existing room tags not coming down sync when joining a room

View file

@ -17,13 +17,14 @@ ignore =
[pep8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik
# doesn't like it. E203 is contrary to PEP8.
ignore = W503,E203
# doesn't like it. E203 is contrary to PEP8. E731 is silly.
ignore = W503,E203,E731
[flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90
ignore=W503,E203,E731
[isort]
line_length = 89

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.33.3"
__version__ = "0.33.4"

View file

@ -251,6 +251,7 @@ class FilterCollection(object):
"include_leave", False
)
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from six.moves import urllib
from prometheus_client import Counter
@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id))
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
urllib.parse.quote(protocol)
)
try:
response = yield self.get_json(uri, fields)
@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
urllib.parse.quote(protocol)
)
try:
info = yield self.get_json(uri, {})
@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient):
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" %
urllib.quote(txn_id))
urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,

View file

@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import namedtuple
import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@ -133,34 +136,25 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted):
def callback(_, pdu):
with logcontext.PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return prune_event(pdu)
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return prune_event(pdu)
return pdu
@ -173,16 +167,116 @@ class FederationBase(object):
)
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
callbackArgs=[pdu],
errbackArgs=[pdu],
)
return deferreds
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
])):
pass
def _check_sigs_on_pdus(keyring, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
pdus (Collection[EventBase]): the events to be checked
Returns:
List[Deferred]: a Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
# (currently this is written assuming the v1 room structure; we'll probably want a
# separate function for checking v2 rooms)
# we want to check that the event is signed by:
#
# (a) the server which created the event_id
#
# (b) the sender's server.
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
# invite, but the event may come from a different HS, for reasons that I don't
# entirely grok (why do the senders have to match? and if they do, why doesn't the
# joining server ask the inviting server to do the switcheroo with
# exchange_third_party_invite?).
#
# That's pretty awful, since redacting such an invite will render it invalid
# (because it will then look like a regular invite without a valid signature),
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
# first make sure that the event is signed by the event_id's domain
deferreds = keyring.verify_json_objects_for_server([
(p.event_id_domain, p.redacted_pdu_json)
for p in pdus_to_check
])
for p, d in zip(pdus_to_check, deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks.
pdus_to_check_sender = [
p for p in pdus_to_check
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
for p in pdus_to_check_sender
])
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
"""Given a list of one or more deferreds, either return the single deferred, or
combine into a DeferredList.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
else:
assert len(deferreds) == 1
return deferreds[0]
def _is_invite_via_3pid(event):
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation

View file

@ -99,7 +99,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@ -108,34 +108,33 @@ class FederationServer(FederationBase):
if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id")
if not transaction.origin:
raise Exception("Transaction missing origin")
logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
(transaction.origin, transaction.transaction_id),
(origin, transaction.transaction_id),
)):
result = yield self._handle_incoming_transaction(
transaction, request_time,
origin, transaction, request_time,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _handle_incoming_transaction(self, transaction, request_time):
def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
origin (unicode): the server making the request
transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at
Returns:
Deferred[(int, object)]: http response code and body
"""
response = yield self.transaction_actions.have_responded(transaction)
response = yield self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
@ -149,7 +148,7 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(transaction.origin)
origin_host, _ = parse_server_name(origin)
pdus_by_room = {}
@ -190,7 +189,7 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
try:
yield self._handle_received_pdu(
transaction.origin, pdu
origin, pdu
)
pdu_results[event_id] = {}
except FederationError as e:
@ -212,7 +211,7 @@ class FederationServer(FederationBase):
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
yield self.received_edu(
transaction.origin,
origin,
edu.edu_type,
edu.content
)
@ -224,6 +223,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(
origin,
transaction,
200, response
)
@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
)
return self._send_edu(
edu_type=edu_type,
origin=origin,
content=content,
edu_type=edu_type,
origin=origin,
content=content,
)
def on_query(self, query_type, args):
@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return handler(args)
return self._get_query_client(
query_type=query_type,
args=args,
query_type=query_type,
args=args,
)

View file

@ -36,7 +36,7 @@ class TransactionActions(object):
self.store = datastore
@log_function
def have_responded(self, transaction):
def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and
origin?
@ -50,11 +50,11 @@ class TransactionActions(object):
"transaction_id")
return self.store.get_received_txn_response(
transaction.transaction_id, transaction.origin
transaction.transaction_id, origin
)
@log_function
def set_response(self, transaction, code, response):
def set_response(self, origin, transaction, code, response):
""" Persist how we responded to a transaction.
Returns:
@ -66,7 +66,7 @@ class TransactionActions(object):
return self.store.set_received_txn_response(
transaction.transaction_id,
transaction.origin,
origin,
code,
response,
)

View file

@ -463,7 +463,19 @@ class TransactionQueue(object):
# pending_transactions flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
# We can only include at most 50 PDUs per transactions
pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
if leftover_pdus:
self.pending_pdus_by_dest[destination] = leftover_pdus
pending_edus = self.pending_edus_by_dest.pop(destination, [])
# We can only include at most 100 EDUs per transactions
pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
if leftover_edus:
self.pending_edus_by_dest[destination] = leftover_edus
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_edus.extend(

View file

@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
try:
code, response = yield self.handler.on_incoming_transaction(
transaction_data
origin, transaction_data,
)
except Exception:
logger.exception("on_incoming_transaction failed")

View file

@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):
Args:
password (unicode): Password to hash.
stored_hash (unicode): Expected hash value.
stored_hash (bytes): Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8')
stored_hash
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii')
return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(),

View file

@ -330,7 +330,8 @@ class E2eKeysHandler(object):
(algorithm, key_id, ex_json, key)
)
else:
new_keys.append((algorithm, key_id, encode_canonical_json(key)))
new_keys.append((
algorithm, key_id, encode_canonical_json(key).decode('ascii')))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
@ -358,7 +359,7 @@ def _exception_to_failure(e):
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
return {
"status": 503, "message": str(e.message),
"status": 503, "message": str(e),
}

View file

@ -594,7 +594,7 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
for event in events + state_events.values() + auth_events.values()
for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events
)
auth_events.update({
@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
logger.info(str(e))
continue
except FederationDeniedError as e:
logger.info(e)
@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
_, state = state_groups.items().pop()
_, state = list(state_groups.items()).pop()
results = state
if event.is_state():

View file

@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)

View file

@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
b = decode_base64(batch)
b = decode_base64(batch).decode('ascii')
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64("%s\n%s\n%s" % (
global_next_batch = encode_base64(("%s\n%s\n%s" % (
batch_group, batch_group_key, pagination_token
))
)).encode('ascii'))
else:
global_next_batch = encode_base64("%s\n%s\n%s" % (
global_next_batch = encode_base64(("%s\n%s\n%s" % (
"all", "", pagination_token
))
)).encode('ascii'))
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64("%s\n%s\n%s" % (
group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
)).encode('ascii'))
allowed_events.extend(room_events)

View file

@ -553,22 +553,36 @@ class SyncHandler(object):
summary = {}
# TODO: only send these when they change.
summary["m.joined_member_count"] = details.get(Membership.JOIN, {}).get('count', 0)
summary["m.invited_member_count"] = details.get(Membership.INVITE, {}).get('count', 0)
summary["m.joined_member_count"] = (
details.get(Membership.JOIN, ([], 0))[1]
)
summary["m.invited_member_count"] = (
details.get(Membership.INVITE, ([], 0))[1]
)
if name_id or canonical_alias_id:
defer.returnValue(summary)
joined_user_ids = [r[0] for r in details.get(Membership.JOIN, {}).get('users', [])]
invited_user_ids = [r[0] for r in details.get(Membership.INVITE, {}).get('users', [])]
joined_user_ids = [
r[0] for r in details.get(Membership.JOIN, ([], 0))[0]
]
invited_user_ids = [
r[0] for r in details.get(Membership.INVITE, ([], 0))[0]
]
gone_user_ids = (
[r[0] for r in details.get(Membership.LEAVE, {}).get('users', [])] +
[r[0] for r in details.get(Membership.BAN, {}).get('users', [])]
[r[0] for r in details.get(Membership.LEAVE, ([], 0))[0]] +
[r[0] for r in details.get(Membership.BAN, ([], 0))[0]]
)
# FIXME: only build up a member_ids list for our heroes
member_ids = {}
for m in (Membership.JOIN, Membership.INVITE, Membership.LEAVE, Membership.BAN):
for r in details.get(m, {}).get('users', []):
for m in (
Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
Membership.BAN
):
for r in details.get(m, ([], 0))[0]:
member_ids[r[0]] = r[1]
# FIXME: order by stream ordering rather than as returned by SQL
@ -775,7 +789,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
for t, event_id in state_ids.iteritems()
for t, event_id in iteritems(state_ids)
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
@ -1576,6 +1590,19 @@ class SyncHandler(object):
newly_joined_room=newly_joined,
)
# When we join the room (or the client requests full_state), we should
# send down any existing tags. Usually the user won't have tags in a
# newly joined room, unless either a) they've joined before or b) the
# tag was added by synapse e.g. for server notice rooms.
if full_state:
user_id = sync_result_builder.sync_config.user.to_string()
tags = yield self.store.get_tags_for_room(user_id, room_id)
# If there aren't any tags, don't send the empty tags list down
# sync
if not tags:
tags = None
account_data_events = []
if tags is not None:
account_data_events.append({
@ -1608,7 +1635,12 @@ class SyncHandler(object):
sync_config.filter_collection.lazy_load_members() and
(
any(ev.type == EventTypes.Member for ev in batch.events) or
#(batch.limited and any(ev.type == EventTypes.Member for ev in state)) or
(
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited and
any(t == EventTypes.Member for (t, k) in state.keys())
) or
since_token is None
)
):
@ -1731,17 +1763,17 @@ def _calculate_state(
event_id_to_key = {
e: key
for key, e in itertools.chain(
timeline_contains.items(),
previous.items(),
timeline_start.items(),
current.items(),
iteritems(timeline_contains),
iteritems(previous),
iteritems(timeline_start),
iteritems(current),
)
}
c_ids = set(e for e in current.values())
ts_ids = set(e for e in timeline_start.values())
p_ids = set(e for e in previous.values())
tc_ids = set(e for e in timeline_contains.values())
c_ids = set(e for e in itervalues(current))
ts_ids = set(e for e in itervalues(timeline_start))
p_ids = set(e for e in itervalues(previous))
tc_ids = set(e for e in itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@ -1755,7 +1787,7 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
e for t, e in timeline_start.iteritems()
e for t, e in iteritems(timeline_start)
if t[0] == EventTypes.Member
)

View file

@ -13,24 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from six import StringIO
from six import text_type
from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json
from prometheus_client import Counter
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder,
HTTPConnectionPool,
PartialDownloadError,
@ -83,18 +84,20 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
self.user_agent = self.user_agent.encode('ascii')
@defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs):
def request(self, method, uri, data=b'', headers=None):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri))
logger.info("Sending request %s %s", method, redact_uri(uri.encode('ascii')))
try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
request_deferred = treq.request(
method, uri, agent=self.agent, data=data, headers=headers
)
add_timeout_to_deferred(
request_deferred, 60, self.hs.get_reactor(),
@ -105,14 +108,14 @@ class SimpleHttpClient(object):
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
method, redact_uri(uri), response.code
method, redact_uri(uri.encode('ascii')), response.code
)
defer.returnValue(response)
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message
method, redact_uri(uri.encode('ascii')), type(e).__name__, e.args[0]
)
raise
@ -137,7 +140,8 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
query_bytes = urllib.parse.urlencode(
encode_urlencode_args(args), True).encode("utf8")
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@ -148,15 +152,14 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
data=query_bytes
)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
body = yield make_deferred_yieldable(treq.json_content(response))
defer.returnValue(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@ -191,9 +194,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@ -248,7 +251,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
json_str = encode_canonical_json(json_body)
@ -262,9 +265,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"PUT",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@ -293,7 +296,7 @@ class SimpleHttpClient(object):
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
@ -304,7 +307,7 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
)
@ -339,7 +342,7 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
url.encode("ascii"),
url,
headers=Headers(actual_headers),
)
@ -434,12 +437,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)
response = yield self.request(
"POST",
url.encode("ascii"),
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
url,
data=query_bytes,
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
@ -510,7 +513,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
if isinstance(arg, unicode):
if isinstance(arg, text_type):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
@ -542,26 +545,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname, port):
return self
class FileBodyProducer(TwistedFileBodyProducer):
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
We override the pauseProducing and resumeProducing methods in twisted's
FileBodyProducer so that they do not raise exceptions if the task has
already completed.
"""
def pauseProducing(self):
try:
super(FileBodyProducer, self).pauseProducing()
except task.TaskDone:
# task has already completed
pass
def resumeProducing(self):
try:
super(FileBodyProducer, self).resumeProducing()
except task.NotPaused:
# task was not paused (probably because it had already completed)
pass

View file

@ -17,19 +17,19 @@ import cgi
import logging
import random
import sys
import urllib
from six import string_types
from six.moves.urllib import parse as urlparse
from six import PY3, string_types
from six.moves import urllib
from canonicaljson import encode_canonical_json, json
import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from twisted.internet import defer, protocol, reactor
from twisted.internet.error import DNSLookupError
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
import synapse.metrics
@ -58,13 +58,18 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
if PY3:
MAXINT = sys.maxsize
else:
MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc
destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
reactor, destination, timeout=10,
@ -93,26 +98,32 @@ class MatrixFederationHttpClient(object):
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string = hs.version_string
self.version_string = hs.version_string.encode('ascii')
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
return urllib.parse.urlunparse(
(b"matrix", destination, path_bytes, param_bytes, query_bytes, b"")
)
@defer.inlineCallbacks
def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
json=None, json_callback=None,
param_bytes=b"",
query=None, retry_on_dns_fail=True,
timeout=None, long_retries=False,
ignore_backoff=False,
backoff_on_404=False):
""" Creates and sends a request to the given server
"""
Creates and sends a request to the given server.
Args:
destination (str): The remote server to send the HTTP request to.
method (str): HTTP method
path (str): The HTTP path
json (dict or None): JSON to send in the body.
json_callback (func or None): A callback to generate the JSON.
query (dict or None): Query arguments.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
@ -146,22 +157,29 @@ class MatrixFederationHttpClient(object):
ignore_backoff=ignore_backoff,
)
destination = destination.encode("ascii")
headers_dict = {}
path_bytes = path.encode("ascii")
with limiter:
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
if query:
query_bytes = encode_query_args(query)
else:
query_bytes = b""
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
)
headers_dict = {
"User-Agent": [self.version_string],
"Host": [destination],
}
with limiter:
url = self._create_url(
destination.encode("ascii"), path_bytes, param_bytes, query_bytes
).decode('ascii')
txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
self._next_id = (self._next_id + 1) % (MAXINT - 1)
outbound_logger.info(
"{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes
txn_id, destination, method, url
)
# XXX: Would be much nicer to retry only at the transaction-layer
@ -171,23 +189,33 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
http_url = urllib.parse.urlunparse(
(b"", b"", path_bytes, param_bytes, query_bytes, b"")
).decode('ascii')
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
try:
request_deferred = self.agent.request(
if json_callback:
json = json_callback()
if json:
data = encode_canonical_json(json)
headers_dict["Content-Type"] = ["application/json"]
self.sign_request(
destination, method, http_url, headers_dict, json
)
else:
data = None
self.sign_request(destination, method, http_url, headers_dict)
request_deferred = treq.request(
method,
url_bytes,
Headers(headers_dict),
producer
url,
headers=Headers(headers_dict),
data=data,
agent=self.agent,
)
add_timeout_to_deferred(
request_deferred,
@ -218,7 +246,7 @@ class MatrixFederationHttpClient(object):
txn_id,
destination,
method,
url_bytes,
url,
_flatten_response_never_received(e),
)
@ -252,7 +280,7 @@ class MatrixFederationHttpClient(object):
# :'(
# Update transactions table?
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.content(response)
raise HttpResponseException(
response.code, response.phrase, body
)
@ -297,11 +325,11 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
auth_headers.append((
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
)).encode('ascii')
)
headers_dict[b"Authorization"] = auth_headers
@ -347,24 +375,14 @@ class MatrixFederationHttpClient(object):
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
json_data_callback = lambda: data
response = yield self._request(
destination,
"PUT",
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
query_bytes=encode_query_args(args),
json_callback=json_data_callback,
query=args,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@ -376,8 +394,8 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
body = yield treq.json_content(response)
defer.returnValue(body)
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
@ -410,20 +428,12 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._request(
destination,
"POST",
path,
query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
query=args,
json=data,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@ -434,9 +444,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
@ -471,16 +481,11 @@ class MatrixFederationHttpClient(object):
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request(
destination,
"GET",
path,
query_bytes=encode_query_args(args),
body_callback=body_callback,
query=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
ignore_backoff=ignore_backoff,
@ -491,9 +496,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False,
@ -523,13 +528,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
response = yield self._request(
destination,
"DELETE",
path,
query_bytes=encode_query_args(args),
headers_dict={"Content-Type": ["application/json"]},
query=args,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@ -540,9 +543,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
@ -569,26 +572,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request(
destination,
"GET",
path,
query_bytes=query_bytes,
body_callback=body_callback,
query=args,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@ -639,30 +627,6 @@ def _readBodyToFile(response, stream, max_size):
return d
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
def resumeProducing(self):
pass
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@ -693,7 +657,7 @@ def check_content_type_is_json(headers):
"No Content-Type header"
)
c_type = c_type[0] # only the first header
c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RuntimeError(
@ -711,6 +675,6 @@ def encode_query_args(args):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes
return query_bytes.encode('utf8')

View file

@ -204,14 +204,14 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
self.start_time, name=servlet_name, method=self.method,
self.start_time, name=servlet_name, method=self.method.decode('ascii'),
)
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.method.decode('ascii'),
self.get_redacted_uri()
)

View file

@ -40,9 +40,10 @@ REQUIREMENTS = {
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=17.1.0": ["twisted>=17.1.0"],
"treq>=15.1": ["treq>=15.1"],
# We use crypto.get_elliptic_curve which is only supported in >=0.15
"pyopenssl>=0.15": ["OpenSSL>=0.15"],
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],

View file

@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
(p.name,): len(p.pending_commands) for p in connected_connections
},
)
@ -607,9 +607,9 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
(p.name,): transport_buffer_size(p) for p in connected_connections
},
)
@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_send_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
(p.name,): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
)
@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge(
tcp_transport_kernel_read_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_read_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
(p.name,): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
)
@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge(
tcp_inbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_inbound_commands",
"",
["command", "name", "conn_id"],
["command", "name"],
lambda: {
(k[0], p.name, p.conn_id): count
(k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge(
tcp_outbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_outbound_commands",
"",
["command", "name", "conn_id"],
["command", "name"],
lambda: {
(k[0], p.name, p.conn_id): count
(k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},

View file

@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
@staticmethod
def encode_response(time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
event_formatter = format_event_raw
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
joined = SyncRestServlet.encode_joined(
sync_result.joined, time_now, access_token_id, filter.event_fields
sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
invited = SyncRestServlet.encode_invited(
sync_result.invited, time_now, access_token_id,
event_formatter,
)
archived = SyncRestServlet.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
return {
@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
}
@staticmethod
def encode_joined(rooms, time_now, token_id, event_fields):
def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields
room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
@staticmethod
def encode_invited(rooms, time_now, token_id):
def encode_invited(rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet):
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
of transaction IDs
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
for room in rooms:
invite = serialize_event(
room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
event_format=event_formatter,
is_invite=True,
)
unsigned = dict(invite.get("unsigned", {}))
@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
return invited
@staticmethod
def encode_archived(rooms, time_now, token_id, event_fields):
def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields
room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
@staticmethod
def encode_room(room, time_now, token_id, joined=True, only_fields=None):
def encode_room(
room, time_now, token_id, joined,
only_fields, event_formatter,
):
"""
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
return serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
event_format=event_formatter,
only_event_fields=only_fields,
)

View file

@ -199,10 +199,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
return
is_trial = yield self.is_trial_user(user_id)
if is_trial:
# we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)

View file

@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase):
def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {'medium': 'email', 'address': 'reserved@server.com'}
unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
self.hs.config.mau_limits_reserved_threepids = [threepid]

View file

@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")

View file

@ -33,7 +33,7 @@ from ..utils import (
)
def _expect_edu(destination, edu_type, content, origin="test"):
def _expect_edu_transaction(edu_type, content, origin="test"):
return {
"origin": origin,
"origin_server_ts": 1000000,
@ -42,10 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"):
}
def _make_edu_json(origin, edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode(
'utf8'
)
def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
class TypingNotificationsTestCase(unittest.TestCase):
@ -190,8 +188,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu(
"farm",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
@ -221,11 +218,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.mock_federation_resource.trigger(
(code, response) = yield self.mock_federation_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1000000/",
_make_edu_json(
"farm",
_make_edu_transaction_json(
"m.typing",
content={
"room_id": self.room_id,
@ -233,7 +229,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": True,
},
),
federation_auth=True,
federation_auth_origin=b'farm',
)
self.on_new_event.assert_has_calls(
@ -264,8 +260,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu(
"farm",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,

View file

@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.client.v1.room
from synapse.api.constants import Membership
from synapse.http.server import JsonResource
from synapse.types import UserID
from synapse.util import Clock
from synapse.rest.client.v1 import room
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
from .utils import RestHelper
PATH_PREFIX = b"/_matrix/client/api/v1"
class RoomBase(unittest.TestCase):
class RoomBase(unittest.HomeserverTestCase):
rmcreator_id = None
def setUp(self):
servlets = [room.register_servlets, room.register_deprecated_servlets]
self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
def make_homeserver(self, reactor, clock):
self.hs = setup_test_homeserver(
self.addCleanup,
self.hs = self.setup_test_homeserver(
"red",
http_client=None,
clock=self.hs_clock,
reactor=self.clock,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234")
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
self.hs.get_datastore().insert_client_ip = _insert_client_ip
self.resource = JsonResource(self.hs)
synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
self.helper = RestHelper(self.hs, self.resource, self.user_id)
return self.hs
class RoomPermissionsTestCase(RoomBase):
""" Tests room permissions. """
user_id = b"@sid1:red"
rmcreator_id = b"@notme:red"
user_id = "@sid1:red"
rmcreator_id = "@notme:red"
def setUp(self):
super(RoomPermissionsTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode('ascii')
request, channel = make_request(
b"PUT",
self.created_rmid_msg_path,
b'{"msgtype":"m.text","body":"test msg"}',
request, channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# set topic for public room
request, channel = make_request(
b"PUT",
request, channel = self.make_request(
"PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
b'{"topic":"Public Room Topic"}',
)
render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase):
seq = iter(range(100))
def send_msg_path():
return b"/rooms/%s/send/m.room.message/mid%s" % (
return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid,
str(next(seq)).encode('ascii'),
str(next(seq)),
)
# send message in uncreated room, expect 403
request, channel = make_request(
b"PUT",
b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
request, channel = self.make_request(
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
request, channel = make_request(b"PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
request, channel = make_request(b"PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", send_msg_path(), msg_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
topic_content = b'{"topic":"My Topic Name"}'
topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
request, channel = make_request(
b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
request, channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = make_request(
b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
request, channel = make_request(b"PUT", topic_path, topic_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = make_request(b"GET", topic_path)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", topic_path, topic_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", topic_path)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
request, channel = make_request(b"PUT", topic_path, topic_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", topic_path, topic_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
request, channel = make_request(b"GET", topic_path)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", topic_path)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
request, channel = make_request(b"PUT", topic_path, topic_content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", topic_path, topic_content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
request, channel = make_request(b"GET", topic_path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assert_dict(json.loads(topic_content), channel.json_body)
request, channel = self.make_request("GET", topic_path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = make_request(b"PUT", topic_path, topic_content)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = make_request(b"GET", topic_path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", topic_path, topic_content)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", topic_path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
request, channel = make_request(
b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid
request, channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
request, channel = make_request(
b"PUT",
b"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
request, channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
path = b"/rooms/%s/state/m.room.member/%s" % (room, member)
request, channel = make_request(b"GET", path)
render(request, self.resource, self.clock)
self.assertEquals(expect_code, int(channel.result["code"]))
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
request, channel = self.make_request("GET", path)
self.render(request)
self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
# === room does not exist ===
@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase):
""" Tests /rooms/$room_id/members/list REST events."""
user_id = b"@sid1:red"
user_id = "@sid1:red"
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members")
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as(b"@some_other_guy:red")
request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
room_id = self.helper.create_room_as("@some_other_guy:red")
request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
room_creator = b"@some_other_guy:red"
room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator)
room_path = b"/rooms/%s/members" % room_id
room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
request, channel = make_request(b"GET", room_path)
render(request, self.resource, self.clock)
self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", room_path)
self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
request, channel = make_request(b"GET", room_path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", room_path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
request, channel = make_request(b"GET", room_path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", room_path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomsCreateTestCase(RoomBase):
""" Tests /rooms and /rooms/$room_id REST events. """
user_id = b"@sid1:red"
user_id = "@sid1:red"
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
request, channel = make_request(b"POST", b"/createRoom", b"{}")
request, channel = self.make_request("POST", "/createRoom", "{}")
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), channel.result)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
request, channel = make_request(
b"POST", b"/createRoom", b'{"visibility":"private"}'
request, channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private"}'
)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}')
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
request, channel = self.make_request(
"POST", "/createRoom", b'{"custom":"stuff"}'
)
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
request, channel = make_request(
b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}'
request, channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
request, channel = make_request(b"POST", b"/createRoom", b'{"visibili')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]))
request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.render(request)
self.assertEquals(400, channel.code)
request, channel = make_request(b"POST", b"/createRoom", b'["hello"]')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]))
request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.render(request)
self.assertEquals(400, channel.code)
class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
user_id = b"@sid1:red"
def setUp(self):
super(RoomTopicTestCase, self).setUp()
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,)
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def test_invalid_puts(self):
# missing keys or invalid json
request, channel = make_request(b"PUT", self.path, '{}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '{}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '{"nao')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '{"nao')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(
b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
request, channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, 'text only')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, 'text only')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", self.path, '')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
request, channel = make_request(b"PUT", self.path, content)
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, content)
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
request, channel = make_request(b"GET", self.path)
render(request, self.resource, self.clock)
self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", self.path)
self.render(request)
self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
request, channel = make_request(b"PUT", self.path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
request, channel = make_request(b"GET", self.path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", self.path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
request, channel = make_request(b"PUT", self.path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
request, channel = make_request(b"GET", self.path)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", self.path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """
user_id = b"@sid1:red"
user_id = "@sid1:red"
def setUp(self):
super(RoomMemberStateTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def tearDown(self):
pass
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '{}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '{"nao')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(
b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]'
request, channel = self.make_request(
"PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, 'text only')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
request, channel = make_request(b"PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, content.encode('ascii'))
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % (
@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
request, channel = make_request(b"PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, content.encode('ascii'))
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEquals(expected_response, channel.json_body)
@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
request, channel = make_request(b"PUT", path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self):
@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE,
"Join us!",
)
request, channel = make_request(b"PUT", path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red"
def setUp(self):
super(RoomMessagesTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'{}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'{"nao')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(
b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
request, channel = self.make_request(
"PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'text only')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '')
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'')
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = '{"body":"test","msgtype":{"type":"a"}}'
request, channel = make_request(b"PUT", path, content)
render(request, self.resource, self.clock)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
content = b'{"body":"test","msgtype":{"type":"a"}}'
request, channel = self.make_request("PUT", path, content)
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
content = '{"body":"test","msgtype":"test.custom.text"}'
request, channel = make_request(b"PUT", path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
content = b'{"body":"test","msgtype":"test.custom.text"}'
request, channel = self.make_request("PUT", path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = '{"body":"test2","msgtype":"m.text"}'
request, channel = make_request(b"PUT", path, content)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
content = b'{"body":"test2","msgtype":"m.text"}'
request, channel = self.make_request("PUT", path, content)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomInitialSyncTestCase(RoomBase):
@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red"
def setUp(self):
super(RoomInitialSyncTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
request, channel = self.make_request(
"GET", "/rooms/%s/initialSync" % self.room_id
)
self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
self.assertEquals("join", channel.json_body["membership"])
@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red"
def setUp(self):
super(RoomMessageListTestCase, self).setUp()
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
request, channel = make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
request, channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)
@ -837,11 +790,11 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
request, channel = make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
request, channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
render(request, self.resource, self.clock)
self.assertEquals(200, int(channel.result["code"]))
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)

View file

@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
set(
[
"next_batch",
"rooms",
"account_data",
"to_device",
"device_lists",
]
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)

View file

@ -65,7 +65,7 @@ class FakeChannel(object):
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
return address.IPv4Address("TCP", "127.0.0.1", 3423)
def getHost(self):
return None

View file

@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._auth.check_auth_blocking = Mock()
mock_event = Mock(
type=EventTypes.Message,
content={"msgtype": ServerNoticeMsgType},
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
# Would be better to check the content, but once == remove blocking event
@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
)
mock_event = Mock(
type=EventTypes.Message,
content={"msgtype": ServerNoticeMsgType},
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
self._send_notice.assert_not_called()
@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
@defer.inlineCallbacks
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(
return_value=1000,
)
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(
return_value=1000,
)
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
# Call the function multiple times to ensure we only send the notice once
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
# Now lets get the last load of messages in the service notice room and
# check that there is only one server notice
room_id = yield self.server_notices_manager.get_notice_room_for_user(
self.user_id,
self.user_id
)
token = yield self.event_source.get_current_token()
events, _ = yield self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key,
room_id, limit=100, end_token=token.room_key
)
count = 0

View file

@ -101,13 +101,11 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
yield self.store.register(user_id=user_id, token="123", password_hash=None)
active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from tests.unittest import HomeserverTestCase
@ -23,7 +26,8 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
hs.config.limit_usage_by_mau = True
hs.config.max_mau_value = 50
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
@ -73,7 +77,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
# Test that regalar users are removed from the db
# Test that regular users are removed from the db
ru_count = 2
self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server")
@ -139,3 +143,43 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "user_id"
self.store.register(
user_id=user_id, token="123", password_hash=None, make_guest=True
)
self.store.upsert_monthly_active_user = Mock()
self.store.populate_monthly_active_users(user_id)
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock()
self.store.is_trial_user = Mock(
return_value=defer.succeed(False)
)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
self.store.populate_monthly_active_users('user_id')
self.pump()
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock()
self.store.is_trial_user = Mock(
return_value=defer.succeed(False)
)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(
self.hs.get_clock().time_msec()
)
)
self.store.populate_monthly_active_users('user_id')
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()

View file

@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member]
group,
[],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, True)
self.assertDictEqual(
{},
state_dict,
)
self.assertDictEqual({}, state_dict)
# test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, True)
@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, True)
@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
self.assertEqual(is_all, True)
self.assertDictEqual(
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
)
self.assertEqual(is_all, True)
@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=(
(e1.type, e1.state_key),
),
fetched_keys=((e1.type, e1.state_key),),
)
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
self.assertEqual(is_all, False)
self.assertEqual(
known_absent,
set(
[
(e1.type, e1.state_key),
]
),
)
self.assertDictEqual(
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id,
},
)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
# test that things work with a partial cache
@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member]
group,
[],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, True)
@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, False)
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
)
self.assertEqual(is_all, True)
@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
self.assertEqual(is_all, False)
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
@ -404,18 +384,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
self.assertEqual(is_all, True)
self.assertDictEqual(
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
)
self.assertEqual(is_all, False)
@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
)
self.assertEqual(is_all, True)
self.assertDictEqual(
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)

View file

@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def create_user(self, localpart):
request_data = json.dumps({
"username": localpart,
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
})
request_data = json.dumps(
{
"username": localpart,
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
}
)
request, channel = make_request(b"POST", b"/register", request_data)
request, channel = make_request("POST", "/register", request_data)
render(request, self.resource, self.reactor)
if channel.result["code"] != b"200":
if channel.code != 200:
raise HttpResponseException(
int(channel.result["code"]),
channel.result["reason"],
channel.result["body"],
channel.code, channel.result["reason"], channel.result["body"]
).to_synapse_error()
access_token = channel.json_body["access_token"]
@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase):
return access_token
def do_sync_for_user(self, token):
request, channel = make_request(b"GET", b"/sync", access_token=token)
request, channel = make_request(
"GET", "/sync", access_token=token.encode('ascii')
)
render(request, self.resource, self.reactor)
if channel.result["code"] != b"200":
if channel.code != 200:
raise HttpResponseException(
int(channel.result["code"]),
channel.result["reason"],
channel.result["body"],
channel.code, channel.result["reason"], channel.result["body"]
).to_synapse_error()

View file

@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create, state_key="", content={}, depth=1,
type=EventTypes.Create, state_key="", content={}, depth=1
),
"A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3),

View file

@ -100,8 +100,13 @@ class TestHomeServer(HomeServer):
@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None,
homeserverToUse=TestHomeServer, **kargs
cleanup_func,
name="test",
datastore=None,
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
**kargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@ -147,6 +152,7 @@ def setup_test_homeserver(
config.hs_disabled_message = ""
config.hs_disabled_limit_type = ""
config.max_mau_value = 50
config.mau_trial_days = 0
config.mau_limits_reserved_threepids = []
config.admin_contact = None
config.rc_messages_per_second = 10000
@ -321,7 +327,9 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request, federation_auth=False):
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
""" Fire an HTTP event.
Args:
@ -330,6 +338,7 @@ class MockHttpResource(HttpServer):
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
@ -350,8 +359,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
headers = {}
if federation_auth:
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
if federation_auth_origin is not None:
headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
@ -570,16 +581,16 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new({
"type": EventTypes.Create,
"state_key": "",
"sender": creator_id,
"room_id": room_id,
"content": {},
})
event, context = yield event_creation_handler.create_new_client_event(
builder
builder = event_builder_factory.new(
{
"type": EventTypes.Create,
"state_key": "",
"sender": creator_id,
"room_id": room_id,
"content": {},
}
)
event, context = yield event_creation_handler.create_new_client_event(builder)
yield store.persist_event(event, context)