Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2018-08-20 11:12:18 +01:00
commit f5abc10724
58 changed files with 707 additions and 356 deletions

1
changelog.d/3694.feature Normal file
View file

@ -0,0 +1 @@
Synapse's presence functionality can now be disabled with the "use_presence" configuration option.

1
changelog.d/3700.bugfix Normal file
View file

@ -0,0 +1 @@
Improve HTTP request logging to include all requests

1
changelog.d/3701.bugfix Normal file
View file

@ -0,0 +1 @@
Avoid timing out requests while we are streaming back the response

1
changelog.d/3703.removal Normal file
View file

@ -0,0 +1 @@
The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst).

1
changelog.d/3707.misc Normal file
View file

@ -0,0 +1 @@
add new error type ResourceLimit

1
changelog.d/3708.feature Normal file
View file

@ -0,0 +1 @@
For resource limit blocked users, prevent writing into rooms

1
changelog.d/3709.misc Normal file
View file

@ -0,0 +1 @@
Logcontexts for replication command handlers

1
changelog.d/3710.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning

1
changelog.d/3712.misc Normal file
View file

@ -0,0 +1 @@
Update admin register API documentation to reference a real user ID.

1
changelog.d/3713.bugfix Normal file
View file

@ -0,0 +1 @@
Support more federation endpoints on workers

View file

@ -33,7 +33,7 @@ As an example::
< { < {
"access_token": "token_here", "access_token": "token_here",
"user_id": "@pepper_roni@test", "user_id": "@pepper_roni:localhost",
"home_server": "test", "home_server": "test",
"device_id": "device_id_here" "device_id": "device_id_here"
} }

View file

@ -241,6 +241,14 @@ regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/keys/upload ^/_matrix/client/(api/v1|r0|unstable)/keys/upload
If ``use_presence`` is False in the homeserver config, it can also handle REST
endpoints matching the following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
This "stub" presence handler will pass through ``GET`` request but make the
``PUT`` effectively a no-op.
It will proxy any requests it cannot handle to the main synapse instance. It It will proxy any requests it cannot handle to the main synapse instance. It
must therefore be configured with the location of the main instance, via must therefore be configured with the location of the main instance, via
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -784,10 +784,11 @@ class Auth(object):
MAU cohort MAU cohort
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise ResourceLimitError(
403, self.hs.config.hs_disabled_message, 403, self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEED, errcode=Codes.RESOURCE_LIMIT_EXCEED,
admin_uri=self.hs.config.admin_uri, admin_uri=self.hs.config.admin_uri,
limit_type=self.hs.config.hs_disabled_limit_type
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort # If the user is already part of the MAU cohort
@ -798,8 +799,10 @@ class Auth(object):
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise ResourceLimitError(
403, "Monthly Active User Limits AU Limit Exceeded", 403, "Monthly Active User Limit Exceeded",
admin_uri=self.hs.config.admin_uri, admin_uri=self.hs.config.admin_uri,
errcode=Codes.RESOURCE_LIMIT_EXCEED errcode=Codes.RESOURCE_LIMIT_EXCEED,
limit_type="monthly_active_user"
) )

View file

@ -224,15 +224,34 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super(AuthError, self).__init__(*args, **kwargs)
class ResourceLimitError(SynapseError):
"""
Any error raised when there is a problem with resource usage.
For instance, the monthly active user limit for the server has been exceeded
"""
def __init__(
self, code, msg,
errcode=Codes.RESOURCE_LIMIT_EXCEED,
admin_uri=None,
limit_type=None,
):
self.admin_uri = admin_uri self.admin_uri = admin_uri
super(AuthError, self).__init__(code, msg, errcode=errcode) self.limit_type = limit_type
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
def error_dict(self): def error_dict(self):
return cs_error( return cs_error(
self.msg, self.msg,
self.errcode, self.errcode,
admin_uri=self.admin_uri, admin_uri=self.admin_uri,
limit_type=self.limit_type
) )

View file

@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
logger.info("Metrics now reporting on %s:%d", host, port) logger.info("Metrics now reporting on %s:%d", host, port)
def listen_tcp(bind_addresses, port, factory, backlog=50): def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
""" """
Create a TCP socket for a port and several addresses Create a TCP socket for a port and several addresses
""" """
@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
check_bind_error(e, address, bind_addresses) check_bind_error(e, address, bind_addresses)
def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
""" """
Create an SSL socket for a port and several addresses Create an SSL socket for a port and several addresses
""" """

View file

@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore()) super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events": if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering() max_stream_id = self.store.get_room_max_stream_ordering()

View file

@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self) self.send_handler = FederationSenderHandler(hs, self)
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(FederationSenderReplicationHandler, self).on_rdata( yield super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
self.send_handler.process_replication_rows(stream_name, token, rows) self.send_handler.process_replication_rows(stream_name, token, rows)

View file

@ -165,7 +165,12 @@ class FrontendProxyServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource) KeyUploadServlet(self).register(resource)
PresenceStatusStubServlet(self).register(resource)
# If presence is disabled, use the stub servlet that does
# not allow sending presence
if not self.config.use_presence:
PresenceStatusStubServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
@ -184,7 +189,8 @@ class FrontendProxyServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
self.version_string, self.version_string,
) ),
reactor=self.get_reactor()
) )
logger.info("Synapse client reader now listening on port %d", port) logger.info("Synapse client reader now listening on port %d", port)

View file

@ -525,6 +525,7 @@ def run(hs):
clock.looping_call( clock.looping_call(
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
) )
hs.get_datastore().reap_monthly_active_users()
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_monthly_active_users(): def generate_monthly_active_users():

View file

@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else: else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey) yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events": elif stream_name == "events":
yield self.pusher_pool.on_new_notifications( self.pusher_pool.on_new_notifications(
token, token, token, token,
) )
elif stream_name == "receipts": elif stream_name == "receipts":
yield self.pusher_pool.on_new_receipts( self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows) token, token, set(row.room_id for row in rows)
) )
except Exception: except Exception:

View file

@ -114,8 +114,10 @@ class SynchrotronPresence(object):
logger.info("Presence process_id is %r", self.process_id) logger.info("Presence process_id is %r", self.process_id)
def send_user_sync(self, user_id, is_syncing, last_sync_ms): def send_user_sync(self, user_id, is_syncing, last_sync_ms):
return if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) self.hs.get_tcp_replication().send_user_sync(
user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id): def mark_as_coming_online(self, user_id):
"""A user has started syncing. Send a UserSync to the master, unless they """A user has started syncing. Send a UserSync to the master, unless they
@ -212,12 +214,13 @@ class SynchrotronPresence(object):
yield self.notify_from_replication(states, stream_id) yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
# presence is disabled on matrix.org, so we return the empty set if self.hs.config.use_presence:
return set() return [
return [ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
user_id for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0
if count > 0 ]
] else:
return set()
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -335,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self): def get_streams_to_replicate(self):

View file

@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler() self.user_directory = hs.get_user_directory_handler()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(UserDirectoryReplicationHandler, self).on_rdata( yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
if stream_name == "current_state_deltas": if stream_name == "current_state_deltas":

View file

@ -49,6 +49,9 @@ class ServerConfig(Config):
# "disable" federation # "disable" federation
self.send_federation = config.get("send_federation", True) self.send_federation = config.get("send_federation", True)
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)
# Whether to update the user directory or not. This should be set to # Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker # false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True) self.update_user_directory = config.get("update_user_directory", True)
@ -81,6 +84,7 @@ class ServerConfig(Config):
# Options to disable HS # Options to disable HS
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
# Admin uri to direct users at should their instance become blocked # Admin uri to direct users at should their instance become blocked
# due to resource constraints # due to resource constraints
@ -249,6 +253,9 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver.
use_presence: true
# The GC threshold parameters to pass to `gc.set_threshold`, if defined # The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10] # gc_thresholds: [700, 10, 10]
@ -340,6 +347,32 @@ class ServerConfig(Config):
# - port: 9000 # - port: 9000
# bind_addresses: ['::1', '127.0.0.1'] # bind_addresses: ['::1', '127.0.0.1']
# type: manhole # type: manhole
# Homeserver blocking
#
# How to reach the server admin, used in ResourceLimitError
# admin_uri: 'mailto:admin@server.com'
#
# Global block config
#
# hs_disabled: False
# hs_disabled_message: 'Human readable reason for why the HS is blocked'
# hs_disabled_limit_type: 'error code(str), to help clients decode reason'
#
# Monthly Active User Blocking
#
# Enables monthly active user checking
# limit_usage_by_mau: False
# max_mau_value: 50
#
# Sometimes the server admin will want to ensure certain accounts are
# never blocked by mau checking. These accounts are specified here.
#
# mau_limit_reserved_threepids:
# - medium: 'email'
# address: 'reserved_user@example.com'
""" % locals() """ % locals()
def read_arguments(self, args): def read_arguments(self, args):

View file

@ -58,6 +58,7 @@ class TransactionQueue(object):
""" """
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -308,7 +309,9 @@ class TransactionQueue(object):
Args: Args:
states (list(UserPresenceState)) states (list(UserPresenceState))
""" """
return if not self.hs.config.use_presence:
# No-op if presence is disabled.
return
# First we queue up the new presence by user ID, so multiple presence # First we queue up the new presence by user ID, so multiple presence
# updates in quick successtion are correctly handled # updates in quick successtion are correctly handled

View file

@ -2386,8 +2386,7 @@ class FederationHandler(BaseHandler):
extra_users=extra_users extra_users=extra_users
) )
logcontext.run_in_background( self.pusher_pool.on_new_notifications(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id, event_stream_id, max_stream_id,
) )

View file

@ -372,7 +372,10 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
defer.returnValue([]) # If presence is disabled, return an empty list
if not self.hs.config.use_presence:
defer.returnValue([])
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
[m.user_id for m in room_members], [m.user_id for m in room_members],
as_event=True, as_event=True,

View file

@ -276,10 +276,14 @@ class EventCreationHandler(object):
where *hashes* is a map from algorithm to hash. where *hashes* is a map from algorithm to hash.
If None, they will be requested from the database. If None, they will be requested from the database.
Raises:
ResourceLimitError if server is blocked to some resource being
exceeded
Returns: Returns:
Tuple of created event (FrozenEvent), Context Tuple of created event (FrozenEvent), Context
""" """
yield self.auth.check_auth_blocking(requester.user.to_string())
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder) self.validator.validate_new(builder)
@ -774,11 +778,8 @@ class EventCreationHandler(object):
event, context=context event, context=context
) )
# this intentionally does not yield: we don't care about the result self.pusher_pool.on_new_notifications(
# and don't need to wait for it. event_stream_id, max_stream_id,
run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
) )
def _notify(): def _notify():

View file

@ -395,7 +395,10 @@ class PresenceHandler(object):
"""We've seen the user do something that indicates they're interacting """We've seen the user do something that indicates they're interacting
with the app. with the app.
""" """
return # If presence is disabled, no-op
if not self.hs.config.use_presence:
return
user_id = user.to_string() user_id = user.to_string()
bump_active_time_counter.inc() bump_active_time_counter.inc()
@ -425,7 +428,11 @@ class PresenceHandler(object):
Useful for streams that are not associated with an actual Useful for streams that are not associated with an actual
client that is being used by a user. client that is being used by a user.
""" """
affect_presence = False # Override if it should affect the user's presence, if presence is
# disabled.
if not self.hs.config.use_presence:
affect_presence = False
if affect_presence: if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0) curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1 self.user_to_num_current_syncs[user_id] = curr_sync + 1
@ -471,15 +478,16 @@ class PresenceHandler(object):
Returns: Returns:
set(str): A set of user_id strings. set(str): A set of user_id strings.
""" """
# presence is disabled on matrix.org, so we return the empty set if self.hs.config.use_presence:
return set() syncing_user_ids = {
syncing_user_ids = { user_id for user_id, count in self.user_to_num_current_syncs.items()
user_id for user_id, count in self.user_to_num_current_syncs.items() if count
if count }
} for user_ids in self.external_process_to_current_syncs.values():
for user_ids in self.external_process_to_current_syncs.values(): syncing_user_ids.update(user_ids)
syncing_user_ids.update(user_ids) return syncing_user_ids
return syncing_user_ids else:
return set()
@defer.inlineCallbacks @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):

View file

@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler from ._base import BaseHandler
@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r["room_id"] for r in receipts])) affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext(): self.notifier.on_new_event(
self.notifier.on_new_event( "receipt_key", max_batch_id, rooms=affected_room_ids
"receipt_key", max_batch_id, rooms=affected_room_ids )
) # Note that the min here shouldn't be relied upon to be accurate.
# Note that the min here shouldn't be relied upon to be accurate. self.hs.get_pusherpool().on_new_receipts(
self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids,
min_batch_id, max_batch_id, affected_room_ids )
)
defer.returnValue(True) defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -98,9 +98,13 @@ class RoomCreationHandler(BaseHandler):
Raises: Raises:
SynapseError if the room ID couldn't be stored, or something went SynapseError if the room ID couldn't be stored, or something went
horribly wrong. horribly wrong.
ResourceLimitError if server is blocked to some resource being
exceeded
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
self.auth.check_auth_blocking(user_id)
if not self.spam_checker.user_may_create_room(user_id): if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")

View file

@ -187,6 +187,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
class SyncHandler(object): class SyncHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.hs_config = hs.config
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@ -864,7 +865,7 @@ class SyncHandler(object):
since_token is None and since_token is None and
sync_config.filter_collection.blocks_all_presence() sync_config.filter_collection.blocks_all_presence()
) )
if False and not block_all_presence_data: if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence( yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users sync_result_builder, newly_joined_rooms, newly_joined_users
) )

View file

@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from twisted.web import resource, server from twisted.web import resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
import synapse.events import synapse.events
@ -37,10 +38,13 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def wrap_json_request_handler(h): def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling. """Wraps a request handler method with exception handling.
Also adds logging as per wrap_request_handler_with_logging. Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
The handler must return a deferred. If the deferred succeeds we assume that The handler must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use a response has been sent. If the deferred fails with a SynapseError we use
@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
pretty_print=_request_user_agent_is_curl(request), pretty_print=_request_user_agent_is_curl(request),
) )
return wrap_request_handler_with_logging(wrapped_request_handler) return wrap_async_request_handler(wrapped_request_handler)
def wrap_html_request_handler(h): def wrap_html_request_handler(h):
"""Wraps a request handler method with exception handling. """Wraps a request handler method with exception handling.
Also adds logging as per wrap_request_handler_with_logging. Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
""" """
def wrapped_request_handler(self, request): def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request) d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request) d.addErrback(_return_html_error, request)
return d return d
return wrap_request_handler_with_logging(wrapped_request_handler) return wrap_async_request_handler(wrapped_request_handler)
def _return_html_error(f, request): def _return_html_error(f, request):
@ -170,46 +172,26 @@ def _return_html_error(f, request):
finish_request(request) finish_request(request)
def wrap_request_handler_with_logging(h): def wrap_async_request_handler(h):
"""Wraps a request handler to provide logging and metrics """Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
is correctly recorded against the request metrics/logs.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
As well as calling `request.processing` (which will log the response and The handler may return a deferred, in which case the completion of the request isn't
duration for this request), the wrapped request handler will insert the logged until the deferred completes.
request id into the logging context.
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped_request_handler(self, request): def wrapped_async_request_handler(self, request):
""" with request.processing():
Args: yield h(self, request)
self:
request (synapse.http.site.SynapseRequest):
"""
request_id = request.get_request_id() # we need to preserve_fn here, because the synchronous render method won't yield for
with LoggingContext(request_id) as request_context: # us (obviously)
request_context.request = request_id return preserve_fn(wrapped_async_request_handler)
with Measure(self.clock, "wrapped_request_handler"):
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = self.__class__.__name__
with request.processing(servlet_name):
with PreserveLoggingContext(request_context):
d = defer.maybeDeferred(h, self, request)
# record the arrival of the request *after*
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(request.method,
request.request_metrics.name).inc()
yield d
return wrapped_request_handler
class HttpServer(object): class HttpServer(object):
@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This gets called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render(request) self._async_render(request)
return server.NOT_DONE_YET return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return return
if pretty_print: if pretty_print:
json_bytes = (encode_pretty_printed_json(json_object) + "\n" json_bytes = encode_pretty_printed_json(json_object) + b"\n"
).encode("utf-8")
else: else:
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes # canonicaljson already encodes to bytes
@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors: if send_cors:
set_cors_headers(request) set_cors_headers(request)
request.write(json_bytes) # todo: we can almost certainly avoid this copy and encode the json straight into
finish_request(request) # the bytesIO, but it would involve faffing around with string->bytes wrappers.
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
return NOT_DONE_YET return NOT_DONE_YET

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import logging import logging
import time import time
@ -19,8 +18,8 @@ import time
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.http import redact_uri from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.util.logcontext import ContextResourceUsage, LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,25 +33,43 @@ class SynapseRequest(Request):
It extends twisted's twisted.web.server.Request, and adds: It extends twisted's twisted.web.server.Request, and adds:
* Unique request ID * Unique request ID
* A log context associated with the request
* Redaction of access_token query-params in __repr__ * Redaction of access_token query-params in __repr__
* Logging at start and end * Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint. * Metrics to record CPU, wallclock and DB time by endpoint.
It provides a method `processing` which should be called by the Resource It also provides a method `processing`, which returns a context manager. If this
which is handling the request, and returns a context manager. method is called, the request won't be logged until the context manager is closed;
this is useful for asynchronous request handlers which may go on processing the
request even after the client has disconnected.
Attributes:
logcontext(LoggingContext) : the log context for this request
""" """
def __init__(self, site, channel, *args, **kw): def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = site self.site = site
self._channel = channel self._channel = channel # this is used by the tests
self.authenticated_entity = None self.authenticated_entity = None
self.start_time = 0 self.start_time = 0
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None
global _next_request_seq global _next_request_seq
self.request_seq = _next_request_seq self.request_seq = _next_request_seq
_next_request_seq += 1 _next_request_seq += 1
# whether an asynchronous request handler has called processing()
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
def __repr__(self): def __repr__(self):
# We overwrite this so that we don't log ``access_token`` # We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
@ -74,11 +91,116 @@ class SynapseRequest(Request):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def render(self, resrc): def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
request_id = self.get_request_id()
logcontext = self.logcontext = LoggingContext(request_id)
logcontext.request = request_id
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.site.server_version_string)
return Request.render(self, resrc)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = resrc.__class__.__name__
self._started_processing(servlet_name)
Request.render(self, resrc)
# record the arrival of the request *after*
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(self.method,
self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing("FooServlet"):
yield really_handle_the_request()
Once the context manager is closed, the completion of the request will be logged,
and the various metrics will be updated.
"""
if self._is_processing:
raise RuntimeError("Request is already processing")
self._is_processing = True
try:
yield
except Exception:
# this should already have been caught, and sent back to the client as a 500.
logger.exception("Asynchronous messge handler raised an uncaught exception")
finally:
# the request handler has finished its work and either sent the whole response
# back, or handed over responsibility to a Producer.
self._processing_finished_time = time.time()
self._is_processing = False
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
self._finished_processing()
def finish(self):
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
logging.
"""
self.finish_time = time.time()
Request.finish(self)
if not self._is_processing:
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
do logging.
"""
self.finish_time = time.time()
Request.connectionLost(self, reason)
# we only get here if the connection to the client drops before we send
# the response.
#
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
"Error processing request: %s %s", reason.type, reason.value,
)
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name): def _started_processing(self, servlet_name):
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
be sure to call finished_processing.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.name.
"""
self.start_time = time.time() self.start_time = time.time()
self.request_metrics = RequestMetrics() self.request_metrics = RequestMetrics()
self.request_metrics.start( self.request_metrics.start(
@ -94,13 +216,21 @@ class SynapseRequest(Request):
) )
def _finished_processing(self): def _finished_processing(self):
try: """Log the completion of this request and update the metrics
context = LoggingContext.current_context() """
usage = context.get_resource_usage()
except Exception:
usage = ContextResourceUsage()
end_time = time.time() usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
# we completed the request without anything calling processing()
self._processing_finished_time = time.time()
# the time between receiving the request and the request handler finishing
processing_time = self._processing_finished_time - self.start_time
# the time between the request handler finishing and the response being sent
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes # need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header # from a IDN servname in an auth header
@ -116,22 +246,31 @@ class SynapseRequest(Request):
user_agent = self.get_user_agent() user_agent = self.get_user_agent()
if user_agent is not None: if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace") user_agent = user_agent.decode("utf-8", "replace")
else:
user_agent = "-"
code = str(self.code)
if not self.finished:
# we didn't send the full response before we gave up (presumably because
# the connection dropped)
code += "!"
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
authenticated_entity, authenticated_entity,
end_time - self.start_time, processing_time,
response_send_time,
usage.ru_utime, usage.ru_utime,
usage.ru_stime, usage.ru_stime,
usage.db_sched_duration_sec, usage.db_sched_duration_sec,
usage.db_txn_duration_sec, usage.db_txn_duration_sec,
int(usage.db_txn_count), int(usage.db_txn_count),
self.sentLength, self.sentLength,
self.code, code,
self.method, self.method,
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto, self.clientproto,
@ -140,38 +279,10 @@ class SynapseRequest(Request):
) )
try: try:
self.request_metrics.stop(end_time, self) self.request_metrics.stop(self.finish_time, self)
except Exception as e: except Exception as e:
logger.warn("Failed to stop metrics: %r", e) logger.warn("Failed to stop metrics: %r", e)
@contextlib.contextmanager
def processing(self, servlet_name):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing("FooServlet"):
yield really_handle_the_request()
This will log the request's arrival. Once the context manager is
closed, the completion of the request will be logged, and the various
metrics will be updated.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.servlet_name.
"""
# TODO: we should probably just move this into render() and finish(),
# to save having to call a separate method.
self._started_processing(servlet_name)
yield
self._finished_processing()
class XForwardedForRequest(SynapseRequest): class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):

View file

@ -18,6 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -122,8 +123,14 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'], p['app_id'], p['pushkey'], p['user_name'],
) )
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):
run_as_background_process(
"on_new_notifications",
self._on_new_notifications, min_stream_id, max_stream_id,
)
@defer.inlineCallbacks
def _on_new_notifications(self, min_stream_id, max_stream_id):
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
@ -147,8 +154,14 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
run_as_background_process(
"on_new_receipts",
self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
)
@defer.inlineCallbacks
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive

View file

@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
edu_content = content["content"] edu_content = content["content"]
logger.info( logger.info(
"Got %r edu from $s", "Got %r edu from %s",
edu_type, origin, edu_type, origin,
) )

View file

@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more. Can be overriden in subclasses to handle more.
""" """
logger.info("Received rdata %s -> %s", stream_name, token) logger.info("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows) return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token): def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes """Called when we get new position data. By default this just pokes
@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more. Can be overriden in subclasses to handle more.
""" """
self.store.process_replication_rows(stream_name, token, []) return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data): def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting """When we received a SYNC we wake up any deferreds that were waiting

View file

@ -59,6 +59,12 @@ class Command(object):
""" """
return self.data return self.data
def get_logcontext_id(self):
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
class ServerCommand(Command): class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name. """Sent by the server on new connection and includes the server_name.
@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row), _json_encoder.encode(self.row),
)) ))
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
class PositionCommand(Command): class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without """Sent by the client to tell the client the stream postition without
@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self): def to_line(self):
return " ".join((self.stream_name, str(self.token),)) return " ".join((self.stream_name, str(self.token),))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
class UserSyncCommand(Command): class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or """Sent by the client to inform the server that a user has started or

View file

@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from .commands import ( from .commands import (
@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function # Now lets try and call on_<CMD_NAME> function
try: try:
getattr(self, "on_%s" % (cmd_name,))(cmd) run_as_background_process(
"replication-" + cmd.get_logcontext_id(),
getattr(self, "on_%s" % (cmd_name,)),
cmd,
)
except Exception: except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line) logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data self.name = cmd.data
def on_USER_SYNC(self, cmd): def on_USER_SYNC(self, cmd):
self.streamer.on_user_sync( return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
) )
@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL": if stream_name == "ALL":
# Subscribe to all streams we're publishing to. # Subscribe to all streams we're publishing to.
for stream in iterkeys(self.streamer.streams_by_name): deferreds = [
self.subscribe_to_stream(stream, token) run_in_background(
self.subscribe_to_stream,
stream, token,
)
for stream in iterkeys(self.streamer.streams_by_name)
]
return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else: else:
self.subscribe_to_stream(stream_name, token) return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd): def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token) return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd): def on_REMOVE_PUSHER(self, cmd):
self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) return self.streamer.on_remove_pusher(
cmd.app_id, cmd.push_key, cmd.user_id,
)
def on_INVALIDATE_CACHE(self, cmd): def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd): def on_USER_IP(self, cmd):
self.streamer.on_user_ip( return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen, cmd.last_seen,
) )
@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token) return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd): def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data) return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token): def replicate(self, stream_name, token):
"""Send the subscription request to the server """Send the subscription request to the server

View file

@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except Exception: except Exception:
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
# yield self.presence_handler.set_state(user, state) if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"] login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE is_application_server = login_type == LoginType.APPLICATION_SERVICE
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = ( can_register = (
self.enable_registration self.enable_registration
or is_application_server or is_application_server
or is_using_shared_secret
) )
if not can_register: if not can_register:
raise SynapseError(403, "Registration has been disabled") raise SynapseError(403, "Registration has been disabled")
@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD: self._do_password, LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity, LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service, LoginType.APPLICATION_SERVICE: self._do_app_service,
LoginType.SHARED_SECRET: self._do_shared_secret,
} }
session_info = self._get_session_info(request, session) session_info = self._get_session_info(request, session)
@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
}) })
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
password = register_json["password"].encode("utf-8")
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
if b"\x00" in user:
raise SynapseError(400, "Invalid user")
if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(register_json["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user.lower(),
password=password,
admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
else:
raise SynapseError(
403, "HMAC incorrect",
)
class CreateUserRestServlet(ClientV1RestServlet): class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface """Handles user creation via a server-to-server interface

View file

@ -96,7 +96,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# While Postgres does not require 'LIMIT', but also does not support # While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can # negative LIMIT values. So there is no way to write it that both can
# support # support
query_args = [self.hs.config.max_mau_value] safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
# Must be greater than zero for postgres
safe_guard = safe_guard if safe_guard > 0 else 0
query_args = [safe_guard]
base_sql = """ base_sql = """
DELETE FROM monthly_active_users DELETE FROM monthly_active_users

View file

@ -21,7 +21,7 @@ from twisted.internet import defer
import synapse.handlers.auth import synapse.handlers.auth
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase):
def test_hs_disabled(self): def test_hs_disabled(self):
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)

0
tests/app/__init__.py Normal file
View file

View file

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.app.frontend_proxy import FrontendProxyServer
from tests.unittest import HomeserverTestCase
class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserverToUse=FrontendProxyServer
)
return hs
def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.
"""
# Presence is on
self.hs.config.use_presence = True
config = {
"port": 8080,
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": ["client"]}],
}
# Listen with the config
self.hs._listen_http(config)
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
def test_listen_http_with_presence_disabled(self):
"""
When presence is on, the stub servlet will register.
"""
# Presence is off
self.hs.config.use_presence = False
config = {
"port": 8080,
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": ["client"]}],
}
# Listen with the config
self.hs._listen_http(config)
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.api.errors import synapse.api.errors
from synapse.api.errors import AuthError from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )

View file

@ -17,7 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID from synapse.types import UserID
@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase):
# Test that global lock works # Test that global lock works
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase):
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from synapse.rest.client.v1 import presence
from synapse.types import UserID
from tests import unittest
class PresenceTestCase(unittest.HomeserverTestCase):
""" Tests presence REST API. """
user_id = "@sid:red"
user = UserID.from_string(user_id)
servlets = [presence.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"red", http_client=None, federation_client=Mock()
)
hs.presence_handler = Mock()
return hs
def test_put_presence(self):
"""
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
"""
self.hs.config.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
request, channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
PUT to the status endpoint with use_presence disbled will NOT call
set_state on the presence handler.
"""
self.hs.config.use_presence = False
body = {"presence": "here", "status_msg": "beep boop"}
request, channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)

View file

@ -25,7 +25,7 @@ from synapse.rest.client.v1_only.register import register_servlets
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request, setup_test_homeserver from tests.server import make_request, render, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase): class CreateUserServletTestCase(unittest.TestCase):
@ -77,10 +77,7 @@ class CreateUserServletTestCase(unittest.TestCase):
) )
request, channel = make_request(b"POST", url, request_data) request, channel = make_request(b"POST", url, request_data)
request.render(res) render(request, res, self.clock)
# Advance the clock because it waits
self.clock.advance(1)
self.assertEquals(channel.result["code"], b"200") self.assertEquals(channel.result["code"], b"200")

View file

@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests import unittest from tests import unittest
from tests.server import make_request, wait_until_result from tests.server import make_request, render
class RestTestCase(unittest.TestCase): class RestTestCase(unittest.TestCase):
@ -171,8 +171,7 @@ class RestHelper(object):
request, channel = make_request( request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8') "POST", path, json.dumps(content).encode('utf8')
) )
request.render(self.resource) render(request, self.resource, self.hs.get_reactor())
wait_until_result(self.hs.get_reactor(), channel)
assert channel.result["code"] == b"200", channel.result assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
@ -220,8 +219,7 @@ class RestHelper(object):
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
request.render(self.resource) render(request, self.resource, self.hs.get_reactor())
wait_until_result(self.hs.get_reactor(), channel)
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"

View file

@ -24,8 +24,8 @@ from tests import unittest
from tests.server import ( from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock, ThreadedMemoryReactorClock as MemoryReactorClock,
make_request, make_request,
render,
setup_test_homeserver, setup_test_homeserver,
wait_until_result,
) )
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"
@ -76,8 +76,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"}) self.assertEqual(channel.json_body, {"filter_id": "0"})
@ -91,8 +90,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@ -105,8 +103,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -121,8 +118,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@ -131,8 +127,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@ -143,8 +138,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
@ -153,7 +147,6 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")

View file

@ -11,7 +11,7 @@ from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request, setup_test_homeserver, wait_until_result from tests.server import make_request, render, setup_test_homeserver
class RegisterRestServletTestCase(unittest.TestCase): class RegisterRestServletTestCase(unittest.TestCase):
@ -72,8 +72,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
@ -89,16 +88,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
request, channel = make_request( request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self): def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666}) request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = make_request(b"POST", self.url, request_data) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password") self.assertEquals(channel.json_body["error"], "Invalid password")
@ -106,8 +103,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_bad_username(self): def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"}) request_data = json.dumps({"username": 777, "password": "monkey"})
request, channel = make_request(b"POST", self.url, request_data) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username") self.assertEquals(channel.json_body["error"], "Invalid username")
@ -126,8 +122,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.device_handler.check_device_registered = Mock(return_value=device_id) self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = make_request(b"POST", self.url, request_data) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
@ -149,8 +144,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.register = Mock(return_value=("@user:id", "t")) self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = make_request(b"POST", self.url, request_data) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled") self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@ -162,8 +156,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
@ -177,8 +170,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.allow_guest_access = False self.hs.config.allow_guest_access = False
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource) render(request, self.resource, self.clock)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled") self.assertEquals(channel.json_body["error"], "Guest access is disabled")

View file

@ -13,66 +13,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.types from mock import Mock
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
setup_test_homeserver,
wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.HomeserverTestCase):
USER_ID = "@apple:test" user_id = "@apple:test"
TO_REGISTER = [sync] servlets = [sync.register_servlets]
def setUp(self): def make_homeserver(self, reactor, clock):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.hs = setup_test_homeserver( hs = self.setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock "red", http_client=None, federation_client=Mock()
) )
return hs
self.auth = self.hs.get_auth()
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.resource)
def test_sync_argless(self): def test_sync_argless(self):
request, channel = make_request("GET", "/_matrix/client/r0/sync") request, channel = self.make_request("GET", "/sync")
request.render(self.resource) self.render(request)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertTrue( self.assertTrue(
set( set(
[ [
@ -85,3 +49,25 @@ class FilterTestCase(unittest.TestCase):
] ]
).issubset(set(channel.json_body.keys())) ).issubset(set(channel.json_body.keys()))
) )
def test_sync_presence_disabled(self):
"""
When presence is disabled, the key does not appear in /sync.
"""
self.hs.config.use_presence = False
request, channel = self.make_request("GET", "/sync")
self.render(request)
self.assertEqual(channel.code, 200)
self.assertTrue(
set(
[
"next_batch",
"rooms",
"account_data",
"to_device",
"device_lists",
]
).issubset(set(channel.json_body.keys()))
)

View file

@ -24,6 +24,7 @@ class FakeChannel(object):
""" """
result = attr.ib(default=attr.Factory(dict)) result = attr.ib(default=attr.Factory(dict))
_producer = None
@property @property
def json_body(self): def json_body(self):
@ -49,6 +50,15 @@ class FakeChannel(object):
self.result["body"] += content self.result["body"] += content
def registerProducer(self, producer, streaming):
self._producer = producer
def unregisterProducer(self):
if self._producer is None:
return
self._producer = None
def requestDone(self, _self): def requestDone(self, _self):
self.result["done"] = True self.result["done"] = True
@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
return req, channel return req, channel
def wait_until_result(clock, channel, timeout=100): def wait_until_result(clock, request, timeout=100):
""" """
Wait until the channel has a result. Wait until the request is finished.
""" """
clock.run() clock.run()
x = 0 x = 0
while not channel.result: while not request.finished:
# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()
x += 1 x += 1
if x > timeout: if x > timeout:
@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):
def render(request, resource, clock): def render(request, resource, clock):
request.render(resource) request.render(resource)
wait_until_result(clock, request._channel) wait_until_result(clock, request)
class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadedMemoryReactorClock(MemoryReactorClock):

View file

@ -75,6 +75,19 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
active_count = yield self.store.get_monthly_active_count() active_count = yield self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num) self.assertEquals(active_count, user_num)
# Test that regalar users are removed from the db
ru_count = 2
yield self.store.upsert_monthly_active_user("@ru1:server")
yield self.store.upsert_monthly_active_user("@ru2:server")
active_count = yield self.store.get_monthly_active_count()
self.assertEqual(active_count, user_num + ru_count)
self.hs.config.max_mau_value = user_num
yield self.store.reap_monthly_active_users()
active_count = yield self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_can_insert_and_count_mau(self): def test_can_insert_and_count_mau(self):
count = yield self.store.get_monthly_active_count() count = yield self.store.get_monthly_active_count()

View file

@ -8,7 +8,7 @@ from synapse.http.server import JsonResource
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request, setup_test_homeserver from tests.server import make_request, render, setup_test_homeserver
class JsonResourceTests(unittest.TestCase): class JsonResourceTests(unittest.TestCase):
@ -37,7 +37,7 @@ class JsonResourceTests(unittest.TestCase):
) )
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res) render(request, res, self.reactor)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
@ -55,7 +55,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
@ -78,13 +78,8 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) render(request, res, self.reactor)
# No error has been raised yet
self.assertTrue("code" not in channel.result)
# Advance time, now there's an error
self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
def test_callback_synapseerror(self): def test_callback_synapseerror(self):
@ -100,7 +95,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.result["code"], b'403')
self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@ -121,7 +116,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foobar") request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.result["code"], b'400')
self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["error"], "Unrecognized request")

View file

@ -18,6 +18,8 @@ import logging
from mock import Mock from mock import Mock
from canonicaljson import json
import twisted import twisted
import twisted.logger import twisted.logger
from twisted.trial import unittest from twisted.trial import unittest
@ -241,11 +243,15 @@ class HomeserverTestCase(TestCase):
method (bytes/unicode): The HTTP request method ("verb"). method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such). escaped UTF-8 & spaces and such).
content (bytes): The body of the request. content (bytes or dict): The body of the request. JSON-encoded, if
a dict.
Returns: Returns:
A synapse.http.site.SynapseRequest. A synapse.http.site.SynapseRequest.
""" """
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
return make_request(method, path, content) return make_request(method, path, content)
def render(self, request): def render(self, request):

View file

@ -93,7 +93,8 @@ def setupdb():
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver( def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs cleanup_func, name="test", datastore=None, config=None, reactor=None,
homeserverToUse=HomeServer, **kargs
): ):
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
@ -137,6 +138,7 @@ def setup_test_homeserver(
config.limit_usage_by_mau = False config.limit_usage_by_mau = False
config.hs_disabled = False config.hs_disabled = False
config.hs_disabled_message = "" config.hs_disabled_message = ""
config.hs_disabled_limit_type = ""
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_uri = None config.admin_uri = None
@ -191,7 +193,7 @@ def setup_test_homeserver(
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
if datastore is None: if datastore is None:
hs = HomeServer( hs = homeserverToUse(
name, name,
config=config, config=config,
db_config=config.database_config, db_config=config.database_config,
@ -234,7 +236,7 @@ def setup_test_homeserver(
hs.setup() hs.setup()
else: else:
hs = HomeServer( hs = homeserverToUse(
name, name,
db_pool=None, db_pool=None,
datastore=datastore, datastore=datastore,