mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-19 17:56:19 +03:00
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
This commit is contained in:
commit
f5abc10724
58 changed files with 707 additions and 356 deletions
1
changelog.d/3694.feature
Normal file
1
changelog.d/3694.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Synapse's presence functionality can now be disabled with the "use_presence" configuration option.
|
1
changelog.d/3700.bugfix
Normal file
1
changelog.d/3700.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve HTTP request logging to include all requests
|
1
changelog.d/3701.bugfix
Normal file
1
changelog.d/3701.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Avoid timing out requests while we are streaming back the response
|
1
changelog.d/3703.removal
Normal file
1
changelog.d/3703.removal
Normal file
|
@ -0,0 +1 @@
|
||||||
|
The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst).
|
1
changelog.d/3707.misc
Normal file
1
changelog.d/3707.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
add new error type ResourceLimit
|
1
changelog.d/3708.feature
Normal file
1
changelog.d/3708.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
For resource limit blocked users, prevent writing into rooms
|
1
changelog.d/3709.misc
Normal file
1
changelog.d/3709.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Logcontexts for replication command handlers
|
1
changelog.d/3710.bugfix
Normal file
1
changelog.d/3710.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning
|
1
changelog.d/3712.misc
Normal file
1
changelog.d/3712.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update admin register API documentation to reference a real user ID.
|
1
changelog.d/3713.bugfix
Normal file
1
changelog.d/3713.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support more federation endpoints on workers
|
|
@ -33,7 +33,7 @@ As an example::
|
||||||
|
|
||||||
< {
|
< {
|
||||||
"access_token": "token_here",
|
"access_token": "token_here",
|
||||||
"user_id": "@pepper_roni@test",
|
"user_id": "@pepper_roni:localhost",
|
||||||
"home_server": "test",
|
"home_server": "test",
|
||||||
"device_id": "device_id_here"
|
"device_id": "device_id_here"
|
||||||
}
|
}
|
||||||
|
|
|
@ -241,6 +241,14 @@ regular expressions::
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
|
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
|
||||||
|
|
||||||
|
If ``use_presence`` is False in the homeserver config, it can also handle REST
|
||||||
|
endpoints matching the following regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
|
||||||
|
|
||||||
|
This "stub" presence handler will pass through ``GET`` request but make the
|
||||||
|
``PUT`` effectively a no-op.
|
||||||
|
|
||||||
It will proxy any requests it cannot handle to the main synapse instance. It
|
It will proxy any requests it cannot handle to the main synapse instance. It
|
||||||
must therefore be configured with the location of the main instance, via
|
must therefore be configured with the location of the main instance, via
|
||||||
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
|
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
|
||||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import AuthError, Codes, ResourceLimitError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
@ -784,10 +784,11 @@ class Auth(object):
|
||||||
MAU cohort
|
MAU cohort
|
||||||
"""
|
"""
|
||||||
if self.hs.config.hs_disabled:
|
if self.hs.config.hs_disabled:
|
||||||
raise AuthError(
|
raise ResourceLimitError(
|
||||||
403, self.hs.config.hs_disabled_message,
|
403, self.hs.config.hs_disabled_message,
|
||||||
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||||
admin_uri=self.hs.config.admin_uri,
|
admin_uri=self.hs.config.admin_uri,
|
||||||
|
limit_type=self.hs.config.hs_disabled_limit_type
|
||||||
)
|
)
|
||||||
if self.hs.config.limit_usage_by_mau is True:
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
# If the user is already part of the MAU cohort
|
# If the user is already part of the MAU cohort
|
||||||
|
@ -798,8 +799,10 @@ class Auth(object):
|
||||||
# Else if there is no room in the MAU bucket, bail
|
# Else if there is no room in the MAU bucket, bail
|
||||||
current_mau = yield self.store.get_monthly_active_count()
|
current_mau = yield self.store.get_monthly_active_count()
|
||||||
if current_mau >= self.hs.config.max_mau_value:
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
raise AuthError(
|
raise ResourceLimitError(
|
||||||
403, "Monthly Active User Limits AU Limit Exceeded",
|
403, "Monthly Active User Limit Exceeded",
|
||||||
|
|
||||||
admin_uri=self.hs.config.admin_uri,
|
admin_uri=self.hs.config.admin_uri,
|
||||||
errcode=Codes.RESOURCE_LIMIT_EXCEED
|
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||||
|
limit_type="monthly_active_user"
|
||||||
)
|
)
|
||||||
|
|
|
@ -224,15 +224,34 @@ class NotFoundError(SynapseError):
|
||||||
|
|
||||||
class AuthError(SynapseError):
|
class AuthError(SynapseError):
|
||||||
"""An error raised when there was a problem authorising an event."""
|
"""An error raised when there was a problem authorising an event."""
|
||||||
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if "errcode" not in kwargs:
|
||||||
|
kwargs["errcode"] = Codes.FORBIDDEN
|
||||||
|
super(AuthError, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceLimitError(SynapseError):
|
||||||
|
"""
|
||||||
|
Any error raised when there is a problem with resource usage.
|
||||||
|
For instance, the monthly active user limit for the server has been exceeded
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self, code, msg,
|
||||||
|
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||||
|
admin_uri=None,
|
||||||
|
limit_type=None,
|
||||||
|
):
|
||||||
self.admin_uri = admin_uri
|
self.admin_uri = admin_uri
|
||||||
super(AuthError, self).__init__(code, msg, errcode=errcode)
|
self.limit_type = limit_type
|
||||||
|
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
|
||||||
|
|
||||||
def error_dict(self):
|
def error_dict(self):
|
||||||
return cs_error(
|
return cs_error(
|
||||||
self.msg,
|
self.msg,
|
||||||
self.errcode,
|
self.errcode,
|
||||||
admin_uri=self.admin_uri,
|
admin_uri=self.admin_uri,
|
||||||
|
limit_type=self.limit_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
|
||||||
logger.info("Metrics now reporting on %s:%d", host, port)
|
logger.info("Metrics now reporting on %s:%d", host, port)
|
||||||
|
|
||||||
|
|
||||||
def listen_tcp(bind_addresses, port, factory, backlog=50):
|
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
||||||
"""
|
"""
|
||||||
Create a TCP socket for a port and several addresses
|
Create a TCP socket for a port and several addresses
|
||||||
"""
|
"""
|
||||||
|
@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
|
||||||
check_bind_error(e, address, bind_addresses)
|
check_bind_error(e, address, bind_addresses)
|
||||||
|
|
||||||
|
|
||||||
def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
|
def listen_ssl(
|
||||||
|
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create an SSL socket for a port and several addresses
|
Create an SSL socket for a port and several addresses
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
|
||||||
super(ASReplicationHandler, self).__init__(hs.get_datastore())
|
super(ASReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
|
yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
|
|
||||||
if stream_name == "events":
|
if stream_name == "events":
|
||||||
max_stream_id = self.store.get_room_max_stream_ordering()
|
max_stream_id = self.store.get_room_max_stream_ordering()
|
||||||
|
|
|
@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
|
||||||
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
|
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.send_handler = FederationSenderHandler(hs, self)
|
self.send_handler = FederationSenderHandler(hs, self)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(FederationSenderReplicationHandler, self).on_rdata(
|
yield super(FederationSenderReplicationHandler, self).on_rdata(
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
self.send_handler.process_replication_rows(stream_name, token, rows)
|
self.send_handler.process_replication_rows(stream_name, token, rows)
|
||||||
|
|
|
@ -165,7 +165,12 @@ class FrontendProxyServer(HomeServer):
|
||||||
elif name == "client":
|
elif name == "client":
|
||||||
resource = JsonResource(self, canonical_json=False)
|
resource = JsonResource(self, canonical_json=False)
|
||||||
KeyUploadServlet(self).register(resource)
|
KeyUploadServlet(self).register(resource)
|
||||||
PresenceStatusStubServlet(self).register(resource)
|
|
||||||
|
# If presence is disabled, use the stub servlet that does
|
||||||
|
# not allow sending presence
|
||||||
|
if not self.config.use_presence:
|
||||||
|
PresenceStatusStubServlet(self).register(resource)
|
||||||
|
|
||||||
resources.update({
|
resources.update({
|
||||||
"/_matrix/client/r0": resource,
|
"/_matrix/client/r0": resource,
|
||||||
"/_matrix/client/unstable": resource,
|
"/_matrix/client/unstable": resource,
|
||||||
|
@ -184,7 +189,8 @@ class FrontendProxyServer(HomeServer):
|
||||||
listener_config,
|
listener_config,
|
||||||
root_resource,
|
root_resource,
|
||||||
self.version_string,
|
self.version_string,
|
||||||
)
|
),
|
||||||
|
reactor=self.get_reactor()
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Synapse client reader now listening on port %d", port)
|
logger.info("Synapse client reader now listening on port %d", port)
|
||||||
|
|
|
@ -525,6 +525,7 @@ def run(hs):
|
||||||
clock.looping_call(
|
clock.looping_call(
|
||||||
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
|
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
|
||||||
)
|
)
|
||||||
|
hs.get_datastore().reap_monthly_active_users()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def generate_monthly_active_users():
|
def generate_monthly_active_users():
|
||||||
|
|
|
@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
|
||||||
|
|
||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
|
yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
run_in_background(self.poke_pushers, stream_name, token, rows)
|
run_in_background(self.poke_pushers, stream_name, token, rows)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
|
||||||
else:
|
else:
|
||||||
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
|
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
|
||||||
elif stream_name == "events":
|
elif stream_name == "events":
|
||||||
yield self.pusher_pool.on_new_notifications(
|
self.pusher_pool.on_new_notifications(
|
||||||
token, token,
|
token, token,
|
||||||
)
|
)
|
||||||
elif stream_name == "receipts":
|
elif stream_name == "receipts":
|
||||||
yield self.pusher_pool.on_new_receipts(
|
self.pusher_pool.on_new_receipts(
|
||||||
token, token, set(row.room_id for row in rows)
|
token, token, set(row.room_id for row in rows)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -114,8 +114,10 @@ class SynchrotronPresence(object):
|
||||||
logger.info("Presence process_id is %r", self.process_id)
|
logger.info("Presence process_id is %r", self.process_id)
|
||||||
|
|
||||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||||
return
|
if self.hs.config.use_presence:
|
||||||
self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
|
self.hs.get_tcp_replication().send_user_sync(
|
||||||
|
user_id, is_syncing, last_sync_ms
|
||||||
|
)
|
||||||
|
|
||||||
def mark_as_coming_online(self, user_id):
|
def mark_as_coming_online(self, user_id):
|
||||||
"""A user has started syncing. Send a UserSync to the master, unless they
|
"""A user has started syncing. Send a UserSync to the master, unless they
|
||||||
|
@ -212,12 +214,13 @@ class SynchrotronPresence(object):
|
||||||
yield self.notify_from_replication(states, stream_id)
|
yield self.notify_from_replication(states, stream_id)
|
||||||
|
|
||||||
def get_currently_syncing_users(self):
|
def get_currently_syncing_users(self):
|
||||||
# presence is disabled on matrix.org, so we return the empty set
|
if self.hs.config.use_presence:
|
||||||
return set()
|
return [
|
||||||
return [
|
user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
|
||||||
user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
|
if count > 0
|
||||||
if count > 0
|
]
|
||||||
]
|
else:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
class SynchrotronTyping(object):
|
class SynchrotronTyping(object):
|
||||||
|
@ -335,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
run_in_background(self.process_and_notify, stream_name, token, rows)
|
run_in_background(self.process_and_notify, stream_name, token, rows)
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
def get_streams_to_replicate(self):
|
||||||
|
|
|
@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
|
||||||
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
|
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.user_directory = hs.get_user_directory_handler()
|
self.user_directory = hs.get_user_directory_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(UserDirectoryReplicationHandler, self).on_rdata(
|
yield super(UserDirectoryReplicationHandler, self).on_rdata(
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
if stream_name == "current_state_deltas":
|
if stream_name == "current_state_deltas":
|
||||||
|
|
|
@ -49,6 +49,9 @@ class ServerConfig(Config):
|
||||||
# "disable" federation
|
# "disable" federation
|
||||||
self.send_federation = config.get("send_federation", True)
|
self.send_federation = config.get("send_federation", True)
|
||||||
|
|
||||||
|
# Whether to enable user presence.
|
||||||
|
self.use_presence = config.get("use_presence", True)
|
||||||
|
|
||||||
# Whether to update the user directory or not. This should be set to
|
# Whether to update the user directory or not. This should be set to
|
||||||
# false only if we are updating the user directory in a worker
|
# false only if we are updating the user directory in a worker
|
||||||
self.update_user_directory = config.get("update_user_directory", True)
|
self.update_user_directory = config.get("update_user_directory", True)
|
||||||
|
@ -81,6 +84,7 @@ class ServerConfig(Config):
|
||||||
# Options to disable HS
|
# Options to disable HS
|
||||||
self.hs_disabled = config.get("hs_disabled", False)
|
self.hs_disabled = config.get("hs_disabled", False)
|
||||||
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
||||||
|
self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
|
||||||
|
|
||||||
# Admin uri to direct users at should their instance become blocked
|
# Admin uri to direct users at should their instance become blocked
|
||||||
# due to resource constraints
|
# due to resource constraints
|
||||||
|
@ -249,6 +253,9 @@ class ServerConfig(Config):
|
||||||
# hard limit.
|
# hard limit.
|
||||||
soft_file_limit: 0
|
soft_file_limit: 0
|
||||||
|
|
||||||
|
# Set to false to disable presence tracking on this homeserver.
|
||||||
|
use_presence: true
|
||||||
|
|
||||||
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
|
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
|
||||||
# gc_thresholds: [700, 10, 10]
|
# gc_thresholds: [700, 10, 10]
|
||||||
|
|
||||||
|
@ -340,6 +347,32 @@ class ServerConfig(Config):
|
||||||
# - port: 9000
|
# - port: 9000
|
||||||
# bind_addresses: ['::1', '127.0.0.1']
|
# bind_addresses: ['::1', '127.0.0.1']
|
||||||
# type: manhole
|
# type: manhole
|
||||||
|
|
||||||
|
|
||||||
|
# Homeserver blocking
|
||||||
|
#
|
||||||
|
# How to reach the server admin, used in ResourceLimitError
|
||||||
|
# admin_uri: 'mailto:admin@server.com'
|
||||||
|
#
|
||||||
|
# Global block config
|
||||||
|
#
|
||||||
|
# hs_disabled: False
|
||||||
|
# hs_disabled_message: 'Human readable reason for why the HS is blocked'
|
||||||
|
# hs_disabled_limit_type: 'error code(str), to help clients decode reason'
|
||||||
|
#
|
||||||
|
# Monthly Active User Blocking
|
||||||
|
#
|
||||||
|
# Enables monthly active user checking
|
||||||
|
# limit_usage_by_mau: False
|
||||||
|
# max_mau_value: 50
|
||||||
|
#
|
||||||
|
# Sometimes the server admin will want to ensure certain accounts are
|
||||||
|
# never blocked by mau checking. These accounts are specified here.
|
||||||
|
#
|
||||||
|
# mau_limit_reserved_threepids:
|
||||||
|
# - medium: 'email'
|
||||||
|
# address: 'reserved_user@example.com'
|
||||||
|
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
|
||||||
def read_arguments(self, args):
|
def read_arguments(self, args):
|
||||||
|
|
|
@ -58,6 +58,7 @@ class TransactionQueue(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -308,7 +309,9 @@ class TransactionQueue(object):
|
||||||
Args:
|
Args:
|
||||||
states (list(UserPresenceState))
|
states (list(UserPresenceState))
|
||||||
"""
|
"""
|
||||||
return
|
if not self.hs.config.use_presence:
|
||||||
|
# No-op if presence is disabled.
|
||||||
|
return
|
||||||
|
|
||||||
# First we queue up the new presence by user ID, so multiple presence
|
# First we queue up the new presence by user ID, so multiple presence
|
||||||
# updates in quick successtion are correctly handled
|
# updates in quick successtion are correctly handled
|
||||||
|
|
|
@ -2386,8 +2386,7 @@ class FederationHandler(BaseHandler):
|
||||||
extra_users=extra_users
|
extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
logcontext.run_in_background(
|
self.pusher_pool.on_new_notifications(
|
||||||
self.pusher_pool.on_new_notifications,
|
|
||||||
event_stream_id, max_stream_id,
|
event_stream_id, max_stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -372,7 +372,10 @@ class InitialSyncHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_presence():
|
def get_presence():
|
||||||
defer.returnValue([])
|
# If presence is disabled, return an empty list
|
||||||
|
if not self.hs.config.use_presence:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
states = yield presence_handler.get_states(
|
states = yield presence_handler.get_states(
|
||||||
[m.user_id for m in room_members],
|
[m.user_id for m in room_members],
|
||||||
as_event=True,
|
as_event=True,
|
||||||
|
|
|
@ -276,10 +276,14 @@ class EventCreationHandler(object):
|
||||||
where *hashes* is a map from algorithm to hash.
|
where *hashes* is a map from algorithm to hash.
|
||||||
|
|
||||||
If None, they will be requested from the database.
|
If None, they will be requested from the database.
|
||||||
|
Raises:
|
||||||
|
ResourceLimitError if server is blocked to some resource being
|
||||||
|
exceeded
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of created event (FrozenEvent), Context
|
Tuple of created event (FrozenEvent), Context
|
||||||
"""
|
"""
|
||||||
|
yield self.auth.check_auth_blocking(requester.user.to_string())
|
||||||
|
|
||||||
builder = self.event_builder_factory.new(event_dict)
|
builder = self.event_builder_factory.new(event_dict)
|
||||||
|
|
||||||
self.validator.validate_new(builder)
|
self.validator.validate_new(builder)
|
||||||
|
@ -774,11 +778,8 @@ class EventCreationHandler(object):
|
||||||
event, context=context
|
event, context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
# this intentionally does not yield: we don't care about the result
|
self.pusher_pool.on_new_notifications(
|
||||||
# and don't need to wait for it.
|
event_stream_id, max_stream_id,
|
||||||
run_in_background(
|
|
||||||
self.pusher_pool.on_new_notifications,
|
|
||||||
event_stream_id, max_stream_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _notify():
|
def _notify():
|
||||||
|
|
|
@ -395,7 +395,10 @@ class PresenceHandler(object):
|
||||||
"""We've seen the user do something that indicates they're interacting
|
"""We've seen the user do something that indicates they're interacting
|
||||||
with the app.
|
with the app.
|
||||||
"""
|
"""
|
||||||
return
|
# If presence is disabled, no-op
|
||||||
|
if not self.hs.config.use_presence:
|
||||||
|
return
|
||||||
|
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
bump_active_time_counter.inc()
|
bump_active_time_counter.inc()
|
||||||
|
@ -425,7 +428,11 @@ class PresenceHandler(object):
|
||||||
Useful for streams that are not associated with an actual
|
Useful for streams that are not associated with an actual
|
||||||
client that is being used by a user.
|
client that is being used by a user.
|
||||||
"""
|
"""
|
||||||
affect_presence = False
|
# Override if it should affect the user's presence, if presence is
|
||||||
|
# disabled.
|
||||||
|
if not self.hs.config.use_presence:
|
||||||
|
affect_presence = False
|
||||||
|
|
||||||
if affect_presence:
|
if affect_presence:
|
||||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||||
|
@ -471,15 +478,16 @@ class PresenceHandler(object):
|
||||||
Returns:
|
Returns:
|
||||||
set(str): A set of user_id strings.
|
set(str): A set of user_id strings.
|
||||||
"""
|
"""
|
||||||
# presence is disabled on matrix.org, so we return the empty set
|
if self.hs.config.use_presence:
|
||||||
return set()
|
syncing_user_ids = {
|
||||||
syncing_user_ids = {
|
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
if count
|
||||||
if count
|
}
|
||||||
}
|
for user_ids in self.external_process_to_current_syncs.values():
|
||||||
for user_ids in self.external_process_to_current_syncs.values():
|
syncing_user_ids.update(user_ids)
|
||||||
syncing_user_ids.update(user_ids)
|
return syncing_user_ids
|
||||||
return syncing_user_ids
|
else:
|
||||||
|
return set()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
|
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
|
||||||
|
|
|
@ -18,7 +18,6 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):
|
||||||
|
|
||||||
affected_room_ids = list(set([r["room_id"] for r in receipts]))
|
affected_room_ids = list(set([r["room_id"] for r in receipts]))
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
self.notifier.on_new_event(
|
||||||
self.notifier.on_new_event(
|
"receipt_key", max_batch_id, rooms=affected_room_ids
|
||||||
"receipt_key", max_batch_id, rooms=affected_room_ids
|
)
|
||||||
)
|
# Note that the min here shouldn't be relied upon to be accurate.
|
||||||
# Note that the min here shouldn't be relied upon to be accurate.
|
self.hs.get_pusherpool().on_new_receipts(
|
||||||
self.hs.get_pusherpool().on_new_receipts(
|
min_batch_id, max_batch_id, affected_room_ids,
|
||||||
min_batch_id, max_batch_id, affected_room_ids
|
)
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
@logcontext.preserve_fn # caller should not yield on this
|
@logcontext.preserve_fn # caller should not yield on this
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -98,9 +98,13 @@ class RoomCreationHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if the room ID couldn't be stored, or something went
|
SynapseError if the room ID couldn't be stored, or something went
|
||||||
horribly wrong.
|
horribly wrong.
|
||||||
|
ResourceLimitError if server is blocked to some resource being
|
||||||
|
exceeded
|
||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
self.auth.check_auth_blocking(user_id)
|
||||||
|
|
||||||
if not self.spam_checker.user_may_create_room(user_id):
|
if not self.spam_checker.user_may_create_room(user_id):
|
||||||
raise SynapseError(403, "You are not permitted to create rooms")
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
||||||
|
|
|
@ -187,6 +187,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||||
class SyncHandler(object):
|
class SyncHandler(object):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.hs_config = hs.config
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
@ -864,7 +865,7 @@ class SyncHandler(object):
|
||||||
since_token is None and
|
since_token is None and
|
||||||
sync_config.filter_collection.blocks_all_presence()
|
sync_config.filter_collection.blocks_all_presence()
|
||||||
)
|
)
|
||||||
if False and not block_all_presence_data:
|
if self.hs_config.use_presence and not block_all_presence_data:
|
||||||
yield self._generate_sync_entry_for_presence(
|
yield self._generate_sync_entry_for_presence(
|
||||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
from twisted.web import resource, server
|
from twisted.web import resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
from twisted.web.static import NoRangeStaticProducer
|
||||||
from twisted.web.util import redirectTo
|
from twisted.web.util import redirectTo
|
||||||
|
|
||||||
import synapse.events
|
import synapse.events
|
||||||
|
@ -37,10 +38,13 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
UnrecognizedRequestError,
|
UnrecognizedRequestError,
|
||||||
)
|
)
|
||||||
from synapse.http.request_metrics import requests_counter
|
|
||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.util.metrics import Measure
|
|
||||||
|
if PY3:
|
||||||
|
from io import BytesIO
|
||||||
|
else:
|
||||||
|
from cStringIO import StringIO as BytesIO
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
||||||
def wrap_json_request_handler(h):
|
def wrap_json_request_handler(h):
|
||||||
"""Wraps a request handler method with exception handling.
|
"""Wraps a request handler method with exception handling.
|
||||||
|
|
||||||
Also adds logging as per wrap_request_handler_with_logging.
|
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||||
|
|
||||||
The handler method must have a signature of "handle_foo(self, request)",
|
The handler method must have a signature of "handle_foo(self, request)",
|
||||||
where "self" must have a "clock" attribute (and "request" must be a
|
where "request" must be a SynapseRequest.
|
||||||
SynapseRequest).
|
|
||||||
|
|
||||||
The handler must return a deferred. If the deferred succeeds we assume that
|
The handler must return a deferred. If the deferred succeeds we assume that
|
||||||
a response has been sent. If the deferred fails with a SynapseError we use
|
a response has been sent. If the deferred fails with a SynapseError we use
|
||||||
|
@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
return wrap_request_handler_with_logging(wrapped_request_handler)
|
return wrap_async_request_handler(wrapped_request_handler)
|
||||||
|
|
||||||
|
|
||||||
def wrap_html_request_handler(h):
|
def wrap_html_request_handler(h):
|
||||||
"""Wraps a request handler method with exception handling.
|
"""Wraps a request handler method with exception handling.
|
||||||
|
|
||||||
Also adds logging as per wrap_request_handler_with_logging.
|
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||||
|
|
||||||
The handler method must have a signature of "handle_foo(self, request)",
|
The handler method must have a signature of "handle_foo(self, request)",
|
||||||
where "self" must have a "clock" attribute (and "request" must be a
|
where "request" must be a SynapseRequest.
|
||||||
SynapseRequest).
|
|
||||||
"""
|
"""
|
||||||
def wrapped_request_handler(self, request):
|
def wrapped_request_handler(self, request):
|
||||||
d = defer.maybeDeferred(h, self, request)
|
d = defer.maybeDeferred(h, self, request)
|
||||||
d.addErrback(_return_html_error, request)
|
d.addErrback(_return_html_error, request)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return wrap_request_handler_with_logging(wrapped_request_handler)
|
return wrap_async_request_handler(wrapped_request_handler)
|
||||||
|
|
||||||
|
|
||||||
def _return_html_error(f, request):
|
def _return_html_error(f, request):
|
||||||
|
@ -170,46 +172,26 @@ def _return_html_error(f, request):
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
def wrap_request_handler_with_logging(h):
|
def wrap_async_request_handler(h):
|
||||||
"""Wraps a request handler to provide logging and metrics
|
"""Wraps an async request handler so that it calls request.processing.
|
||||||
|
|
||||||
|
This helps ensure that work done by the request handler after the request is completed
|
||||||
|
is correctly recorded against the request metrics/logs.
|
||||||
|
|
||||||
The handler method must have a signature of "handle_foo(self, request)",
|
The handler method must have a signature of "handle_foo(self, request)",
|
||||||
where "self" must have a "clock" attribute (and "request" must be a
|
where "request" must be a SynapseRequest.
|
||||||
SynapseRequest).
|
|
||||||
|
|
||||||
As well as calling `request.processing` (which will log the response and
|
The handler may return a deferred, in which case the completion of the request isn't
|
||||||
duration for this request), the wrapped request handler will insert the
|
logged until the deferred completes.
|
||||||
request id into the logging context.
|
|
||||||
"""
|
"""
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped_request_handler(self, request):
|
def wrapped_async_request_handler(self, request):
|
||||||
"""
|
with request.processing():
|
||||||
Args:
|
yield h(self, request)
|
||||||
self:
|
|
||||||
request (synapse.http.site.SynapseRequest):
|
|
||||||
"""
|
|
||||||
|
|
||||||
request_id = request.get_request_id()
|
# we need to preserve_fn here, because the synchronous render method won't yield for
|
||||||
with LoggingContext(request_id) as request_context:
|
# us (obviously)
|
||||||
request_context.request = request_id
|
return preserve_fn(wrapped_async_request_handler)
|
||||||
with Measure(self.clock, "wrapped_request_handler"):
|
|
||||||
# we start the request metrics timer here with an initial stab
|
|
||||||
# at the servlet name. For most requests that name will be
|
|
||||||
# JsonResource (or a subclass), and JsonResource._async_render
|
|
||||||
# will update it once it picks a servlet.
|
|
||||||
servlet_name = self.__class__.__name__
|
|
||||||
with request.processing(servlet_name):
|
|
||||||
with PreserveLoggingContext(request_context):
|
|
||||||
d = defer.maybeDeferred(h, self, request)
|
|
||||||
|
|
||||||
# record the arrival of the request *after*
|
|
||||||
# dispatching to the handler, so that the handler
|
|
||||||
# can update the servlet name in the request
|
|
||||||
# metrics
|
|
||||||
requests_counter.labels(request.method,
|
|
||||||
request.request_metrics.name).inc()
|
|
||||||
yield d
|
|
||||||
return wrapped_request_handler
|
|
||||||
|
|
||||||
|
|
||||||
class HttpServer(object):
|
class HttpServer(object):
|
||||||
|
@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
""" This gets called by twisted every time someone sends us a request.
|
""" This gets called by twisted every time someone sends us a request.
|
||||||
"""
|
"""
|
||||||
self._async_render(request)
|
self._async_render(request)
|
||||||
return server.NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
||||||
return
|
return
|
||||||
|
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
|
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
|
||||||
).encode("utf-8")
|
|
||||||
else:
|
else:
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||||
# canonicaljson already encodes to bytes
|
# canonicaljson already encodes to bytes
|
||||||
|
@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
|
||||||
if send_cors:
|
if send_cors:
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
|
|
||||||
request.write(json_bytes)
|
# todo: we can almost certainly avoid this copy and encode the json straight into
|
||||||
finish_request(request)
|
# the bytesIO, but it would involve faffing around with string->bytes wrappers.
|
||||||
|
bytes_io = BytesIO(json_bytes)
|
||||||
|
|
||||||
|
producer = NoRangeStaticProducer(request, bytes_io)
|
||||||
|
producer.start()
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
@ -19,8 +18,8 @@ import time
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
|
||||||
from synapse.http import redact_uri
|
from synapse.http import redact_uri
|
||||||
from synapse.http.request_metrics import RequestMetrics
|
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
||||||
from synapse.util.logcontext import ContextResourceUsage, LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -34,25 +33,43 @@ class SynapseRequest(Request):
|
||||||
|
|
||||||
It extends twisted's twisted.web.server.Request, and adds:
|
It extends twisted's twisted.web.server.Request, and adds:
|
||||||
* Unique request ID
|
* Unique request ID
|
||||||
|
* A log context associated with the request
|
||||||
* Redaction of access_token query-params in __repr__
|
* Redaction of access_token query-params in __repr__
|
||||||
* Logging at start and end
|
* Logging at start and end
|
||||||
* Metrics to record CPU, wallclock and DB time by endpoint.
|
* Metrics to record CPU, wallclock and DB time by endpoint.
|
||||||
|
|
||||||
It provides a method `processing` which should be called by the Resource
|
It also provides a method `processing`, which returns a context manager. If this
|
||||||
which is handling the request, and returns a context manager.
|
method is called, the request won't be logged until the context manager is closed;
|
||||||
|
this is useful for asynchronous request handlers which may go on processing the
|
||||||
|
request even after the client has disconnected.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
logcontext(LoggingContext) : the log context for this request
|
||||||
"""
|
"""
|
||||||
def __init__(self, site, channel, *args, **kw):
|
def __init__(self, site, channel, *args, **kw):
|
||||||
Request.__init__(self, channel, *args, **kw)
|
Request.__init__(self, channel, *args, **kw)
|
||||||
self.site = site
|
self.site = site
|
||||||
self._channel = channel
|
self._channel = channel # this is used by the tests
|
||||||
self.authenticated_entity = None
|
self.authenticated_entity = None
|
||||||
self.start_time = 0
|
self.start_time = 0
|
||||||
|
|
||||||
|
# we can't yet create the logcontext, as we don't know the method.
|
||||||
|
self.logcontext = None
|
||||||
|
|
||||||
global _next_request_seq
|
global _next_request_seq
|
||||||
self.request_seq = _next_request_seq
|
self.request_seq = _next_request_seq
|
||||||
_next_request_seq += 1
|
_next_request_seq += 1
|
||||||
|
|
||||||
|
# whether an asynchronous request handler has called processing()
|
||||||
|
self._is_processing = False
|
||||||
|
|
||||||
|
# the time when the asynchronous request handler completed its processing
|
||||||
|
self._processing_finished_time = None
|
||||||
|
|
||||||
|
# what time we finished sending the response to the client (or the connection
|
||||||
|
# dropped)
|
||||||
|
self.finish_time = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# We overwrite this so that we don't log ``access_token``
|
# We overwrite this so that we don't log ``access_token``
|
||||||
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
|
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
|
||||||
|
@ -74,11 +91,116 @@ class SynapseRequest(Request):
|
||||||
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
|
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
|
||||||
|
|
||||||
def render(self, resrc):
|
def render(self, resrc):
|
||||||
|
# this is called once a Resource has been found to serve the request; in our
|
||||||
|
# case the Resource in question will normally be a JsonResource.
|
||||||
|
|
||||||
|
# create a LogContext for this request
|
||||||
|
request_id = self.get_request_id()
|
||||||
|
logcontext = self.logcontext = LoggingContext(request_id)
|
||||||
|
logcontext.request = request_id
|
||||||
|
|
||||||
# override the Server header which is set by twisted
|
# override the Server header which is set by twisted
|
||||||
self.setHeader("Server", self.site.server_version_string)
|
self.setHeader("Server", self.site.server_version_string)
|
||||||
return Request.render(self, resrc)
|
|
||||||
|
with PreserveLoggingContext(self.logcontext):
|
||||||
|
# we start the request metrics timer here with an initial stab
|
||||||
|
# at the servlet name. For most requests that name will be
|
||||||
|
# JsonResource (or a subclass), and JsonResource._async_render
|
||||||
|
# will update it once it picks a servlet.
|
||||||
|
servlet_name = resrc.__class__.__name__
|
||||||
|
self._started_processing(servlet_name)
|
||||||
|
|
||||||
|
Request.render(self, resrc)
|
||||||
|
|
||||||
|
# record the arrival of the request *after*
|
||||||
|
# dispatching to the handler, so that the handler
|
||||||
|
# can update the servlet name in the request
|
||||||
|
# metrics
|
||||||
|
requests_counter.labels(self.method,
|
||||||
|
self.request_metrics.name).inc()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def processing(self):
|
||||||
|
"""Record the fact that we are processing this request.
|
||||||
|
|
||||||
|
Returns a context manager; the correct way to use this is:
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request(request):
|
||||||
|
with request.processing("FooServlet"):
|
||||||
|
yield really_handle_the_request()
|
||||||
|
|
||||||
|
Once the context manager is closed, the completion of the request will be logged,
|
||||||
|
and the various metrics will be updated.
|
||||||
|
"""
|
||||||
|
if self._is_processing:
|
||||||
|
raise RuntimeError("Request is already processing")
|
||||||
|
self._is_processing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except Exception:
|
||||||
|
# this should already have been caught, and sent back to the client as a 500.
|
||||||
|
logger.exception("Asynchronous messge handler raised an uncaught exception")
|
||||||
|
finally:
|
||||||
|
# the request handler has finished its work and either sent the whole response
|
||||||
|
# back, or handed over responsibility to a Producer.
|
||||||
|
|
||||||
|
self._processing_finished_time = time.time()
|
||||||
|
self._is_processing = False
|
||||||
|
|
||||||
|
# if we've already sent the response, log it now; otherwise, we wait for the
|
||||||
|
# response to be sent.
|
||||||
|
if self.finish_time is not None:
|
||||||
|
self._finished_processing()
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
"""Called when all response data has been written to this Request.
|
||||||
|
|
||||||
|
Overrides twisted.web.server.Request.finish to record the finish time and do
|
||||||
|
logging.
|
||||||
|
"""
|
||||||
|
self.finish_time = time.time()
|
||||||
|
Request.finish(self)
|
||||||
|
if not self._is_processing:
|
||||||
|
with PreserveLoggingContext(self.logcontext):
|
||||||
|
self._finished_processing()
|
||||||
|
|
||||||
|
def connectionLost(self, reason):
|
||||||
|
"""Called when the client connection is closed before the response is written.
|
||||||
|
|
||||||
|
Overrides twisted.web.server.Request.connectionLost to record the finish time and
|
||||||
|
do logging.
|
||||||
|
"""
|
||||||
|
self.finish_time = time.time()
|
||||||
|
Request.connectionLost(self, reason)
|
||||||
|
|
||||||
|
# we only get here if the connection to the client drops before we send
|
||||||
|
# the response.
|
||||||
|
#
|
||||||
|
# It's useful to log it here so that we can get an idea of when
|
||||||
|
# the client disconnects.
|
||||||
|
with PreserveLoggingContext(self.logcontext):
|
||||||
|
logger.warn(
|
||||||
|
"Error processing request: %s %s", reason.type, reason.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._is_processing:
|
||||||
|
self._finished_processing()
|
||||||
|
|
||||||
def _started_processing(self, servlet_name):
|
def _started_processing(self, servlet_name):
|
||||||
|
"""Record the fact that we are processing this request.
|
||||||
|
|
||||||
|
This will log the request's arrival. Once the request completes,
|
||||||
|
be sure to call finished_processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
servlet_name (str): the name of the servlet which will be
|
||||||
|
processing this request. This is used in the metrics.
|
||||||
|
|
||||||
|
It is possible to update this afterwards by updating
|
||||||
|
self.request_metrics.name.
|
||||||
|
"""
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.request_metrics = RequestMetrics()
|
self.request_metrics = RequestMetrics()
|
||||||
self.request_metrics.start(
|
self.request_metrics.start(
|
||||||
|
@ -94,13 +216,21 @@ class SynapseRequest(Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _finished_processing(self):
|
def _finished_processing(self):
|
||||||
try:
|
"""Log the completion of this request and update the metrics
|
||||||
context = LoggingContext.current_context()
|
"""
|
||||||
usage = context.get_resource_usage()
|
|
||||||
except Exception:
|
|
||||||
usage = ContextResourceUsage()
|
|
||||||
|
|
||||||
end_time = time.time()
|
usage = self.logcontext.get_resource_usage()
|
||||||
|
|
||||||
|
if self._processing_finished_time is None:
|
||||||
|
# we completed the request without anything calling processing()
|
||||||
|
self._processing_finished_time = time.time()
|
||||||
|
|
||||||
|
# the time between receiving the request and the request handler finishing
|
||||||
|
processing_time = self._processing_finished_time - self.start_time
|
||||||
|
|
||||||
|
# the time between the request handler finishing and the response being sent
|
||||||
|
# to the client (nb may be negative)
|
||||||
|
response_send_time = self.finish_time - self._processing_finished_time
|
||||||
|
|
||||||
# need to decode as it could be raw utf-8 bytes
|
# need to decode as it could be raw utf-8 bytes
|
||||||
# from a IDN servname in an auth header
|
# from a IDN servname in an auth header
|
||||||
|
@ -116,22 +246,31 @@ class SynapseRequest(Request):
|
||||||
user_agent = self.get_user_agent()
|
user_agent = self.get_user_agent()
|
||||||
if user_agent is not None:
|
if user_agent is not None:
|
||||||
user_agent = user_agent.decode("utf-8", "replace")
|
user_agent = user_agent.decode("utf-8", "replace")
|
||||||
|
else:
|
||||||
|
user_agent = "-"
|
||||||
|
|
||||||
|
code = str(self.code)
|
||||||
|
if not self.finished:
|
||||||
|
# we didn't send the full response before we gave up (presumably because
|
||||||
|
# the connection dropped)
|
||||||
|
code += "!"
|
||||||
|
|
||||||
self.site.access_logger.info(
|
self.site.access_logger.info(
|
||||||
"%s - %s - {%s}"
|
"%s - %s - {%s}"
|
||||||
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||||
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
|
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
|
||||||
self.getClientIP(),
|
self.getClientIP(),
|
||||||
self.site.site_tag,
|
self.site.site_tag,
|
||||||
authenticated_entity,
|
authenticated_entity,
|
||||||
end_time - self.start_time,
|
processing_time,
|
||||||
|
response_send_time,
|
||||||
usage.ru_utime,
|
usage.ru_utime,
|
||||||
usage.ru_stime,
|
usage.ru_stime,
|
||||||
usage.db_sched_duration_sec,
|
usage.db_sched_duration_sec,
|
||||||
usage.db_txn_duration_sec,
|
usage.db_txn_duration_sec,
|
||||||
int(usage.db_txn_count),
|
int(usage.db_txn_count),
|
||||||
self.sentLength,
|
self.sentLength,
|
||||||
self.code,
|
code,
|
||||||
self.method,
|
self.method,
|
||||||
self.get_redacted_uri(),
|
self.get_redacted_uri(),
|
||||||
self.clientproto,
|
self.clientproto,
|
||||||
|
@ -140,38 +279,10 @@ class SynapseRequest(Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.request_metrics.stop(end_time, self)
|
self.request_metrics.stop(self.finish_time, self)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Failed to stop metrics: %r", e)
|
logger.warn("Failed to stop metrics: %r", e)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def processing(self, servlet_name):
|
|
||||||
"""Record the fact that we are processing this request.
|
|
||||||
|
|
||||||
Returns a context manager; the correct way to use this is:
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def handle_request(request):
|
|
||||||
with request.processing("FooServlet"):
|
|
||||||
yield really_handle_the_request()
|
|
||||||
|
|
||||||
This will log the request's arrival. Once the context manager is
|
|
||||||
closed, the completion of the request will be logged, and the various
|
|
||||||
metrics will be updated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
servlet_name (str): the name of the servlet which will be
|
|
||||||
processing this request. This is used in the metrics.
|
|
||||||
|
|
||||||
It is possible to update this afterwards by updating
|
|
||||||
self.request_metrics.servlet_name.
|
|
||||||
"""
|
|
||||||
# TODO: we should probably just move this into render() and finish(),
|
|
||||||
# to save having to call a separate method.
|
|
||||||
self._started_processing(servlet_name)
|
|
||||||
yield
|
|
||||||
self._finished_processing()
|
|
||||||
|
|
||||||
|
|
||||||
class XForwardedForRequest(SynapseRequest):
|
class XForwardedForRequest(SynapseRequest):
|
||||||
def __init__(self, *args, **kw):
|
def __init__(self, *args, **kw):
|
||||||
|
|
|
@ -18,6 +18,7 @@ import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push.pusher import PusherFactory
|
from synapse.push.pusher import PusherFactory
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
|
@ -122,8 +123,14 @@ class PusherPool:
|
||||||
p['app_id'], p['pushkey'], p['user_name'],
|
p['app_id'], p['pushkey'], p['user_name'],
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_new_notifications(self, min_stream_id, max_stream_id):
|
def on_new_notifications(self, min_stream_id, max_stream_id):
|
||||||
|
run_as_background_process(
|
||||||
|
"on_new_notifications",
|
||||||
|
self._on_new_notifications, min_stream_id, max_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _on_new_notifications(self, min_stream_id, max_stream_id):
|
||||||
try:
|
try:
|
||||||
users_affected = yield self.store.get_push_action_users_in_range(
|
users_affected = yield self.store.get_push_action_users_in_range(
|
||||||
min_stream_id, max_stream_id
|
min_stream_id, max_stream_id
|
||||||
|
@ -147,8 +154,14 @@ class PusherPool:
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Exception in pusher on_new_notifications")
|
logger.exception("Exception in pusher on_new_notifications")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
||||||
|
run_as_background_process(
|
||||||
|
"on_new_receipts",
|
||||||
|
self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
||||||
try:
|
try:
|
||||||
# Need to subtract 1 from the minimum because the lower bound here
|
# Need to subtract 1 from the minimum because the lower bound here
|
||||||
# is not inclusive
|
# is not inclusive
|
||||||
|
|
|
@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||||
edu_content = content["content"]
|
edu_content = content["content"]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got %r edu from $s",
|
"Got %r edu from %s",
|
||||||
edu_type, origin,
|
edu_type, origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
|
||||||
Can be overriden in subclasses to handle more.
|
Can be overriden in subclasses to handle more.
|
||||||
"""
|
"""
|
||||||
logger.info("Received rdata %s -> %s", stream_name, token)
|
logger.info("Received rdata %s -> %s", stream_name, token)
|
||||||
self.store.process_replication_rows(stream_name, token, rows)
|
return self.store.process_replication_rows(stream_name, token, rows)
|
||||||
|
|
||||||
def on_position(self, stream_name, token):
|
def on_position(self, stream_name, token):
|
||||||
"""Called when we get new position data. By default this just pokes
|
"""Called when we get new position data. By default this just pokes
|
||||||
|
@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
|
||||||
|
|
||||||
Can be overriden in subclasses to handle more.
|
Can be overriden in subclasses to handle more.
|
||||||
"""
|
"""
|
||||||
self.store.process_replication_rows(stream_name, token, [])
|
return self.store.process_replication_rows(stream_name, token, [])
|
||||||
|
|
||||||
def on_sync(self, data):
|
def on_sync(self, data):
|
||||||
"""When we received a SYNC we wake up any deferreds that were waiting
|
"""When we received a SYNC we wake up any deferreds that were waiting
|
||||||
|
|
|
@ -59,6 +59,12 @@ class Command(object):
|
||||||
"""
|
"""
|
||||||
return self.data
|
return self.data
|
||||||
|
|
||||||
|
def get_logcontext_id(self):
|
||||||
|
"""Get a suitable string for the logcontext when processing this command"""
|
||||||
|
|
||||||
|
# by default, we just use the command name.
|
||||||
|
return self.NAME
|
||||||
|
|
||||||
|
|
||||||
class ServerCommand(Command):
|
class ServerCommand(Command):
|
||||||
"""Sent by the server on new connection and includes the server_name.
|
"""Sent by the server on new connection and includes the server_name.
|
||||||
|
@ -116,6 +122,9 @@ class RdataCommand(Command):
|
||||||
_json_encoder.encode(self.row),
|
_json_encoder.encode(self.row),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
def get_logcontext_id(self):
|
||||||
|
return "RDATA-" + self.stream_name
|
||||||
|
|
||||||
|
|
||||||
class PositionCommand(Command):
|
class PositionCommand(Command):
|
||||||
"""Sent by the client to tell the client the stream postition without
|
"""Sent by the client to tell the client the stream postition without
|
||||||
|
@ -190,6 +199,9 @@ class ReplicateCommand(Command):
|
||||||
def to_line(self):
|
def to_line(self):
|
||||||
return " ".join((self.stream_name, str(self.token),))
|
return " ".join((self.stream_name, str(self.token),))
|
||||||
|
|
||||||
|
def get_logcontext_id(self):
|
||||||
|
return "REPLICATE-" + self.stream_name
|
||||||
|
|
||||||
|
|
||||||
class UserSyncCommand(Command):
|
class UserSyncCommand(Command):
|
||||||
"""Sent by the client to inform the server that a user has started or
|
"""Sent by the client to inform the server that a user has started or
|
||||||
|
|
|
@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from .commands import (
|
from .commands import (
|
||||||
|
@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
|
|
||||||
# Now lets try and call on_<CMD_NAME> function
|
# Now lets try and call on_<CMD_NAME> function
|
||||||
try:
|
try:
|
||||||
getattr(self, "on_%s" % (cmd_name,))(cmd)
|
run_as_background_process(
|
||||||
|
"replication-" + cmd.get_logcontext_id(),
|
||||||
|
getattr(self, "on_%s" % (cmd_name,)),
|
||||||
|
cmd,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
|
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
|
||||||
|
|
||||||
|
@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
self.name = cmd.data
|
self.name = cmd.data
|
||||||
|
|
||||||
def on_USER_SYNC(self, cmd):
|
def on_USER_SYNC(self, cmd):
|
||||||
self.streamer.on_user_sync(
|
return self.streamer.on_user_sync(
|
||||||
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
|
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
|
|
||||||
if stream_name == "ALL":
|
if stream_name == "ALL":
|
||||||
# Subscribe to all streams we're publishing to.
|
# Subscribe to all streams we're publishing to.
|
||||||
for stream in iterkeys(self.streamer.streams_by_name):
|
deferreds = [
|
||||||
self.subscribe_to_stream(stream, token)
|
run_in_background(
|
||||||
|
self.subscribe_to_stream,
|
||||||
|
stream, token,
|
||||||
|
)
|
||||||
|
for stream in iterkeys(self.streamer.streams_by_name)
|
||||||
|
]
|
||||||
|
|
||||||
|
return make_deferred_yieldable(
|
||||||
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.subscribe_to_stream(stream_name, token)
|
return self.subscribe_to_stream(stream_name, token)
|
||||||
|
|
||||||
def on_FEDERATION_ACK(self, cmd):
|
def on_FEDERATION_ACK(self, cmd):
|
||||||
self.streamer.federation_ack(cmd.token)
|
return self.streamer.federation_ack(cmd.token)
|
||||||
|
|
||||||
def on_REMOVE_PUSHER(self, cmd):
|
def on_REMOVE_PUSHER(self, cmd):
|
||||||
self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
return self.streamer.on_remove_pusher(
|
||||||
|
cmd.app_id, cmd.push_key, cmd.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
def on_INVALIDATE_CACHE(self, cmd):
|
def on_INVALIDATE_CACHE(self, cmd):
|
||||||
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||||
|
|
||||||
def on_USER_IP(self, cmd):
|
def on_USER_IP(self, cmd):
|
||||||
self.streamer.on_user_ip(
|
return self.streamer.on_user_ip(
|
||||||
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
|
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
|
||||||
cmd.last_seen,
|
cmd.last_seen,
|
||||||
)
|
)
|
||||||
|
@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
# Check if this is the last of a batch of updates
|
# Check if this is the last of a batch of updates
|
||||||
rows = self.pending_batches.pop(stream_name, [])
|
rows = self.pending_batches.pop(stream_name, [])
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
return self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||||
self.handler.on_rdata(stream_name, cmd.token, rows)
|
|
||||||
|
|
||||||
def on_POSITION(self, cmd):
|
def on_POSITION(self, cmd):
|
||||||
self.handler.on_position(cmd.stream_name, cmd.token)
|
return self.handler.on_position(cmd.stream_name, cmd.token)
|
||||||
|
|
||||||
def on_SYNC(self, cmd):
|
def on_SYNC(self, cmd):
|
||||||
self.handler.on_sync(cmd.data)
|
return self.handler.on_sync(cmd.data)
|
||||||
|
|
||||||
def replicate(self, stream_name, token):
|
def replicate(self, stream_name, token):
|
||||||
"""Send the subscription request to the server
|
"""Send the subscription request to the server
|
||||||
|
|
|
@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "Unable to parse state")
|
raise SynapseError(400, "Unable to parse state")
|
||||||
|
|
||||||
# yield self.presence_handler.set_state(user, state)
|
if self.hs.config.use_presence:
|
||||||
|
yield self.presence_handler.set_state(user, state)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
|
@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
login_type = register_json["type"]
|
login_type = register_json["type"]
|
||||||
|
|
||||||
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
||||||
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
|
||||||
|
|
||||||
can_register = (
|
can_register = (
|
||||||
self.enable_registration
|
self.enable_registration
|
||||||
or is_application_server
|
or is_application_server
|
||||||
or is_using_shared_secret
|
|
||||||
)
|
)
|
||||||
if not can_register:
|
if not can_register:
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
LoginType.PASSWORD: self._do_password,
|
LoginType.PASSWORD: self._do_password,
|
||||||
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
||||||
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
||||||
LoginType.SHARED_SECRET: self._do_shared_secret,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
session_info = self._get_session_info(request, session)
|
session_info = self._get_session_info(request, session)
|
||||||
|
@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _do_shared_secret(self, request, register_json, session):
|
|
||||||
assert_params_in_dict(register_json, ["mac", "user", "password"])
|
|
||||||
|
|
||||||
if not self.hs.config.registration_shared_secret:
|
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
|
||||||
|
|
||||||
user = register_json["user"].encode("utf-8")
|
|
||||||
password = register_json["password"].encode("utf-8")
|
|
||||||
admin = register_json.get("admin", None)
|
|
||||||
|
|
||||||
# Its important to check as we use null bytes as HMAC field separators
|
|
||||||
if b"\x00" in user:
|
|
||||||
raise SynapseError(400, "Invalid user")
|
|
||||||
if b"\x00" in password:
|
|
||||||
raise SynapseError(400, "Invalid password")
|
|
||||||
|
|
||||||
# str() because otherwise hmac complains that 'unicode' does not
|
|
||||||
# have the buffer interface
|
|
||||||
got_mac = str(register_json["mac"])
|
|
||||||
|
|
||||||
want_mac = hmac.new(
|
|
||||||
key=self.hs.config.registration_shared_secret.encode(),
|
|
||||||
digestmod=sha1,
|
|
||||||
)
|
|
||||||
want_mac.update(user)
|
|
||||||
want_mac.update(b"\x00")
|
|
||||||
want_mac.update(password)
|
|
||||||
want_mac.update(b"\x00")
|
|
||||||
want_mac.update(b"admin" if admin else b"notadmin")
|
|
||||||
want_mac = want_mac.hexdigest()
|
|
||||||
|
|
||||||
if compare_digest(want_mac, got_mac):
|
|
||||||
handler = self.handlers.registration_handler
|
|
||||||
user_id, token = yield handler.register(
|
|
||||||
localpart=user.lower(),
|
|
||||||
password=password,
|
|
||||||
admin=bool(admin),
|
|
||||||
)
|
|
||||||
self._remove_session(session)
|
|
||||||
defer.returnValue({
|
|
||||||
"user_id": user_id,
|
|
||||||
"access_token": token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
raise SynapseError(
|
|
||||||
403, "HMAC incorrect",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CreateUserRestServlet(ClientV1RestServlet):
|
class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
"""Handles user creation via a server-to-server interface
|
"""Handles user creation via a server-to-server interface
|
||||||
|
|
|
@ -96,7 +96,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||||
# While Postgres does not require 'LIMIT', but also does not support
|
# While Postgres does not require 'LIMIT', but also does not support
|
||||||
# negative LIMIT values. So there is no way to write it that both can
|
# negative LIMIT values. So there is no way to write it that both can
|
||||||
# support
|
# support
|
||||||
query_args = [self.hs.config.max_mau_value]
|
safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
|
||||||
|
# Must be greater than zero for postgres
|
||||||
|
safe_guard = safe_guard if safe_guard > 0 else 0
|
||||||
|
query_args = [safe_guard]
|
||||||
|
|
||||||
base_sql = """
|
base_sql = """
|
||||||
DELETE FROM monthly_active_users
|
DELETE FROM monthly_active_users
|
||||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.handlers.auth
|
import synapse.handlers.auth
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import AuthError, Codes, ResourceLimitError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
return_value=defer.succeed(lots_of_users)
|
return_value=defer.succeed(lots_of_users)
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as e:
|
with self.assertRaises(ResourceLimitError) as e:
|
||||||
yield self.auth.check_auth_blocking()
|
yield self.auth.check_auth_blocking()
|
||||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
|
@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_hs_disabled(self):
|
def test_hs_disabled(self):
|
||||||
self.hs.config.hs_disabled = True
|
self.hs.config.hs_disabled = True
|
||||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||||
with self.assertRaises(AuthError) as e:
|
with self.assertRaises(ResourceLimitError) as e:
|
||||||
yield self.auth.check_auth_blocking()
|
yield self.auth.check_auth_blocking()
|
||||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
|
|
0
tests/app/__init__.py
Normal file
0
tests/app/__init__.py
Normal file
88
tests/app/test_frontend_proxy.py
Normal file
88
tests/app/test_frontend_proxy.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.app.frontend_proxy import FrontendProxyServer
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class FrontendProxyTests(HomeserverTestCase):
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
hs = self.setup_test_homeserver(
|
||||||
|
http_client=None, homeserverToUse=FrontendProxyServer
|
||||||
|
)
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def test_listen_http_with_presence_enabled(self):
|
||||||
|
"""
|
||||||
|
When presence is on, the stub servlet will not register.
|
||||||
|
"""
|
||||||
|
# Presence is on
|
||||||
|
self.hs.config.use_presence = True
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"port": 8080,
|
||||||
|
"bind_addresses": ["0.0.0.0"],
|
||||||
|
"resources": [{"names": ["client"]}],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Listen with the config
|
||||||
|
self.hs._listen_http(config)
|
||||||
|
|
||||||
|
# Grab the resource from the site that was told to listen
|
||||||
|
self.assertEqual(len(self.reactor.tcpServers), 1)
|
||||||
|
site = self.reactor.tcpServers[0][1]
|
||||||
|
self.resource = (
|
||||||
|
site.resource.children["_matrix"].children["client"].children["r0"]
|
||||||
|
)
|
||||||
|
|
||||||
|
request, channel = self.make_request("PUT", "presence/a/status")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
# 400 + unrecognised, because nothing is registered
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
||||||
|
|
||||||
|
def test_listen_http_with_presence_disabled(self):
|
||||||
|
"""
|
||||||
|
When presence is on, the stub servlet will register.
|
||||||
|
"""
|
||||||
|
# Presence is off
|
||||||
|
self.hs.config.use_presence = False
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"port": 8080,
|
||||||
|
"bind_addresses": ["0.0.0.0"],
|
||||||
|
"resources": [{"names": ["client"]}],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Listen with the config
|
||||||
|
self.hs._listen_http(config)
|
||||||
|
|
||||||
|
# Grab the resource from the site that was told to listen
|
||||||
|
self.assertEqual(len(self.reactor.tcpServers), 1)
|
||||||
|
site = self.reactor.tcpServers[0][1]
|
||||||
|
self.resource = (
|
||||||
|
site.resource.children["_matrix"].children["client"].children["r0"]
|
||||||
|
)
|
||||||
|
|
||||||
|
request, channel = self.make_request("PUT", "presence/a/status")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
# 401, because the stub servlet still checks authentication
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
|
|
@ -20,7 +20,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import ResourceLimitError
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase):
|
||||||
return_value=defer.succeed(self.large_number_of_users)
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.large_number_of_users)
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
)
|
)
|
||||||
|
@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
self._get_macaroon().serialize()
|
self._get_macaroon().serialize()
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,7 +17,7 @@ from mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import ResourceLimitError
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
|
@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||||
|
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.register(localpart="local_part")
|
yield self.handler.register(localpart="local_part")
|
||||||
|
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.register(localpart="local_part")
|
yield self.handler.register(localpart="local_part")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.lots_of_users)
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.register_saml2(localpart="local_part")
|
yield self.handler.register_saml2(localpart="local_part")
|
||||||
|
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(ResourceLimitError):
|
||||||
yield self.handler.register_saml2(localpart="local_part")
|
yield self.handler.register_saml2(localpart="local_part")
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
||||||
from synapse.handlers.sync import SyncConfig, SyncHandler
|
from synapse.handlers.sync import SyncConfig, SyncHandler
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
# Test that global lock works
|
# Test that global lock works
|
||||||
self.hs.config.hs_disabled = True
|
self.hs.config.hs_disabled = True
|
||||||
with self.assertRaises(AuthError) as e:
|
with self.assertRaises(ResourceLimitError) as e:
|
||||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
sync_config = self._generate_sync_config(user_id2)
|
sync_config = self._generate_sync_config(user_id2)
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as e:
|
with self.assertRaises(ResourceLimitError) as e:
|
||||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||||
|
|
||||||
|
|
72
tests/rest/client/v1/test_presence.py
Normal file
72
tests/rest/client/v1/test_presence.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
from synapse.rest.client.v1 import presence
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
|
""" Tests presence REST API. """
|
||||||
|
|
||||||
|
user_id = "@sid:red"
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
servlets = [presence.register_servlets]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
hs = self.setup_test_homeserver(
|
||||||
|
"red", http_client=None, federation_client=Mock()
|
||||||
|
)
|
||||||
|
|
||||||
|
hs.presence_handler = Mock()
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def test_put_presence(self):
|
||||||
|
"""
|
||||||
|
PUT to the status endpoint with use_presence enabled will call
|
||||||
|
set_state on the presence handler.
|
||||||
|
"""
|
||||||
|
self.hs.config.use_presence = True
|
||||||
|
|
||||||
|
body = {"presence": "here", "status_msg": "beep boop"}
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
|
||||||
|
|
||||||
|
def test_put_presence_disabled(self):
|
||||||
|
"""
|
||||||
|
PUT to the status endpoint with use_presence disbled will NOT call
|
||||||
|
set_state on the presence handler.
|
||||||
|
"""
|
||||||
|
self.hs.config.use_presence = False
|
||||||
|
|
||||||
|
body = {"presence": "here", "status_msg": "beep boop"}
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"PUT", "/presence/%s/status" % (self.user_id,), body
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
|
|
@ -25,7 +25,7 @@ from synapse.rest.client.v1_only.register import register_servlets
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request, setup_test_homeserver
|
from tests.server import make_request, render, setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class CreateUserServletTestCase(unittest.TestCase):
|
class CreateUserServletTestCase(unittest.TestCase):
|
||||||
|
@ -77,10 +77,7 @@ class CreateUserServletTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = make_request(b"POST", url, request_data)
|
request, channel = make_request(b"POST", url, request_data)
|
||||||
request.render(res)
|
render(request, res, self.clock)
|
||||||
|
|
||||||
# Advance the clock because it waits
|
|
||||||
self.clock.advance(1)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200")
|
self.assertEquals(channel.result["code"], b"200")
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request, wait_until_result
|
from tests.server import make_request, render
|
||||||
|
|
||||||
|
|
||||||
class RestTestCase(unittest.TestCase):
|
class RestTestCase(unittest.TestCase):
|
||||||
|
@ -171,8 +171,7 @@ class RestHelper(object):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"POST", path, json.dumps(content).encode('utf8')
|
"POST", path, json.dumps(content).encode('utf8')
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
wait_until_result(self.hs.get_reactor(), channel)
|
|
||||||
|
|
||||||
assert channel.result["code"] == b"200", channel.result
|
assert channel.result["code"] == b"200", channel.result
|
||||||
self.auth_user_id = temp_id
|
self.auth_user_id = temp_id
|
||||||
|
@ -220,8 +219,7 @@ class RestHelper(object):
|
||||||
|
|
||||||
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
|
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
|
||||||
|
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
wait_until_result(self.hs.get_reactor(), channel)
|
|
||||||
|
|
||||||
assert int(channel.result["code"]) == expect_code, (
|
assert int(channel.result["code"]) == expect_code, (
|
||||||
"Expected: %d, got: %d, resp: %r"
|
"Expected: %d, got: %d, resp: %r"
|
||||||
|
|
|
@ -24,8 +24,8 @@ from tests import unittest
|
||||||
from tests.server import (
|
from tests.server import (
|
||||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||||
make_request,
|
make_request,
|
||||||
|
render,
|
||||||
setup_test_homeserver,
|
setup_test_homeserver,
|
||||||
wait_until_result,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||||
|
@ -76,8 +76,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||||
|
@ -91,8 +90,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||||
|
@ -105,8 +103,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.hs.is_mine = _is_mine
|
self.hs.is_mine = _is_mine
|
||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
|
@ -121,8 +118,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
||||||
|
@ -131,8 +127,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||||
|
@ -143,8 +138,7 @@ class FilterTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
|
|
||||||
|
@ -153,7 +147,6 @@ class FilterTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
|
|
|
@ -11,7 +11,7 @@ from synapse.rest.client.v2_alpha.register import register_servlets
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request, setup_test_homeserver, wait_until_result
|
from tests.server import make_request, render, setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServletTestCase(unittest.TestCase):
|
class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
|
@ -72,8 +72,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
det_data = {
|
det_data = {
|
||||||
|
@ -89,16 +88,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||||
)
|
)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
|
||||||
def test_POST_bad_password(self):
|
def test_POST_bad_password(self):
|
||||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = make_request(b"POST", self.url, request_data)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Invalid password")
|
self.assertEquals(channel.json_body["error"], "Invalid password")
|
||||||
|
@ -106,8 +103,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
def test_POST_bad_username(self):
|
def test_POST_bad_username(self):
|
||||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = make_request(b"POST", self.url, request_data)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Invalid username")
|
self.assertEquals(channel.json_body["error"], "Invalid username")
|
||||||
|
@ -126,8 +122,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = make_request(b"POST", self.url, request_data)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -149,8 +144,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = make_request(b"POST", self.url, request_data)
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||||
|
@ -162,8 +156,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -177,8 +170,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.config.allow_guest_access = False
|
self.hs.config.allow_guest_access = False
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
request.render(self.resource)
|
render(request, self.resource, self.clock)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
||||||
|
|
|
@ -13,66 +13,30 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import synapse.types
|
from mock import Mock
|
||||||
from synapse.http.server import JsonResource
|
|
||||||
from synapse.rest.client.v2_alpha import sync
|
from synapse.rest.client.v2_alpha import sync
|
||||||
from synapse.types import UserID
|
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import (
|
|
||||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
|
||||||
make_request,
|
|
||||||
setup_test_homeserver,
|
|
||||||
wait_until_result,
|
|
||||||
)
|
|
||||||
|
|
||||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
|
||||||
|
|
||||||
|
|
||||||
class FilterTestCase(unittest.TestCase):
|
class FilterTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
USER_ID = "@apple:test"
|
user_id = "@apple:test"
|
||||||
TO_REGISTER = [sync]
|
servlets = [sync.register_servlets]
|
||||||
|
|
||||||
def setUp(self):
|
def make_homeserver(self, reactor, clock):
|
||||||
self.clock = MemoryReactorClock()
|
|
||||||
self.hs_clock = Clock(self.clock)
|
|
||||||
|
|
||||||
self.hs = setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
"red", http_client=None, federation_client=Mock()
|
||||||
)
|
)
|
||||||
|
return hs
|
||||||
self.auth = self.hs.get_auth()
|
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
|
||||||
return {
|
|
||||||
"user": UserID.from_string(self.USER_ID),
|
|
||||||
"token_id": 1,
|
|
||||||
"is_guest": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
|
||||||
return synapse.types.create_requester(
|
|
||||||
UserID.from_string(self.USER_ID), 1, False, None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
|
||||||
self.auth.get_user_by_req = get_user_by_req
|
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
|
||||||
self.filtering = self.hs.get_filtering()
|
|
||||||
self.resource = JsonResource(self.hs)
|
|
||||||
|
|
||||||
for r in self.TO_REGISTER:
|
|
||||||
r.register_servlets(self.hs, self.resource)
|
|
||||||
|
|
||||||
def test_sync_argless(self):
|
def test_sync_argless(self):
|
||||||
request, channel = make_request("GET", "/_matrix/client/r0/sync")
|
request, channel = self.make_request("GET", "/sync")
|
||||||
request.render(self.resource)
|
self.render(request)
|
||||||
wait_until_result(self.clock, channel)
|
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
set(
|
set(
|
||||||
[
|
[
|
||||||
|
@ -85,3 +49,25 @@ class FilterTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
).issubset(set(channel.json_body.keys()))
|
).issubset(set(channel.json_body.keys()))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sync_presence_disabled(self):
|
||||||
|
"""
|
||||||
|
When presence is disabled, the key does not appear in /sync.
|
||||||
|
"""
|
||||||
|
self.hs.config.use_presence = False
|
||||||
|
|
||||||
|
request, channel = self.make_request("GET", "/sync")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertTrue(
|
||||||
|
set(
|
||||||
|
[
|
||||||
|
"next_batch",
|
||||||
|
"rooms",
|
||||||
|
"account_data",
|
||||||
|
"to_device",
|
||||||
|
"device_lists",
|
||||||
|
]
|
||||||
|
).issubset(set(channel.json_body.keys()))
|
||||||
|
)
|
||||||
|
|
|
@ -24,6 +24,7 @@ class FakeChannel(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = attr.ib(default=attr.Factory(dict))
|
result = attr.ib(default=attr.Factory(dict))
|
||||||
|
_producer = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_body(self):
|
def json_body(self):
|
||||||
|
@ -49,6 +50,15 @@ class FakeChannel(object):
|
||||||
|
|
||||||
self.result["body"] += content
|
self.result["body"] += content
|
||||||
|
|
||||||
|
def registerProducer(self, producer, streaming):
|
||||||
|
self._producer = producer
|
||||||
|
|
||||||
|
def unregisterProducer(self):
|
||||||
|
if self._producer is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._producer = None
|
||||||
|
|
||||||
def requestDone(self, _self):
|
def requestDone(self, _self):
|
||||||
self.result["done"] = True
|
self.result["done"] = True
|
||||||
|
|
||||||
|
@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
|
||||||
return req, channel
|
return req, channel
|
||||||
|
|
||||||
|
|
||||||
def wait_until_result(clock, channel, timeout=100):
|
def wait_until_result(clock, request, timeout=100):
|
||||||
"""
|
"""
|
||||||
Wait until the channel has a result.
|
Wait until the request is finished.
|
||||||
"""
|
"""
|
||||||
clock.run()
|
clock.run()
|
||||||
x = 0
|
x = 0
|
||||||
|
|
||||||
while not channel.result:
|
while not request.finished:
|
||||||
|
|
||||||
|
# If there's a producer, tell it to resume producing so we get content
|
||||||
|
if request._channel._producer:
|
||||||
|
request._channel._producer.resumeProducing()
|
||||||
|
|
||||||
x += 1
|
x += 1
|
||||||
|
|
||||||
if x > timeout:
|
if x > timeout:
|
||||||
|
@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):
|
||||||
|
|
||||||
def render(request, resource, clock):
|
def render(request, resource, clock):
|
||||||
request.render(resource)
|
request.render(resource)
|
||||||
wait_until_result(clock, request._channel)
|
wait_until_result(clock, request)
|
||||||
|
|
||||||
|
|
||||||
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
|
|
|
@ -75,6 +75,19 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
|
||||||
active_count = yield self.store.get_monthly_active_count()
|
active_count = yield self.store.get_monthly_active_count()
|
||||||
self.assertEquals(active_count, user_num)
|
self.assertEquals(active_count, user_num)
|
||||||
|
|
||||||
|
# Test that regalar users are removed from the db
|
||||||
|
ru_count = 2
|
||||||
|
yield self.store.upsert_monthly_active_user("@ru1:server")
|
||||||
|
yield self.store.upsert_monthly_active_user("@ru2:server")
|
||||||
|
active_count = yield self.store.get_monthly_active_count()
|
||||||
|
|
||||||
|
self.assertEqual(active_count, user_num + ru_count)
|
||||||
|
self.hs.config.max_mau_value = user_num
|
||||||
|
yield self.store.reap_monthly_active_users()
|
||||||
|
|
||||||
|
active_count = yield self.store.get_monthly_active_count()
|
||||||
|
self.assertEquals(active_count, user_num)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_can_insert_and_count_mau(self):
|
def test_can_insert_and_count_mau(self):
|
||||||
count = yield self.store.get_monthly_active_count()
|
count = yield self.store.get_monthly_active_count()
|
||||||
|
|
|
@ -8,7 +8,7 @@ from synapse.http.server import JsonResource
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request, setup_test_homeserver
|
from tests.server import make_request, render, setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class JsonResourceTests(unittest.TestCase):
|
class JsonResourceTests(unittest.TestCase):
|
||||||
|
@ -37,7 +37,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
||||||
request.render(res)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
||||||
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
|
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
|
||||||
|
@ -55,7 +55,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||||
request.render(res)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'500')
|
self.assertEqual(channel.result["code"], b'500')
|
||||||
|
|
||||||
|
@ -78,13 +78,8 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||||
request.render(res)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
# No error has been raised yet
|
|
||||||
self.assertTrue("code" not in channel.result)
|
|
||||||
|
|
||||||
# Advance time, now there's an error
|
|
||||||
self.reactor.advance(1)
|
|
||||||
self.assertEqual(channel.result["code"], b'500')
|
self.assertEqual(channel.result["code"], b'500')
|
||||||
|
|
||||||
def test_callback_synapseerror(self):
|
def test_callback_synapseerror(self):
|
||||||
|
@ -100,7 +95,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||||
request.render(res)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'403')
|
self.assertEqual(channel.result["code"], b'403')
|
||||||
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
|
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
|
||||||
|
@ -121,7 +116,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
||||||
request.render(res)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'400')
|
self.assertEqual(channel.result["code"], b'400')
|
||||||
self.assertEqual(channel.json_body["error"], "Unrecognized request")
|
self.assertEqual(channel.json_body["error"], "Unrecognized request")
|
||||||
|
|
|
@ -18,6 +18,8 @@ import logging
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from canonicaljson import json
|
||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
import twisted.logger
|
import twisted.logger
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
@ -241,11 +243,15 @@ class HomeserverTestCase(TestCase):
|
||||||
method (bytes/unicode): The HTTP request method ("verb").
|
method (bytes/unicode): The HTTP request method ("verb").
|
||||||
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
|
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
|
||||||
escaped UTF-8 & spaces and such).
|
escaped UTF-8 & spaces and such).
|
||||||
content (bytes): The body of the request.
|
content (bytes or dict): The body of the request. JSON-encoded, if
|
||||||
|
a dict.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A synapse.http.site.SynapseRequest.
|
A synapse.http.site.SynapseRequest.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(content, dict):
|
||||||
|
content = json.dumps(content).encode('utf8')
|
||||||
|
|
||||||
return make_request(method, path, content)
|
return make_request(method, path, content)
|
||||||
|
|
||||||
def render(self, request):
|
def render(self, request):
|
||||||
|
|
|
@ -93,7 +93,8 @@ def setupdb():
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setup_test_homeserver(
|
def setup_test_homeserver(
|
||||||
cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs
|
cleanup_func, name="test", datastore=None, config=None, reactor=None,
|
||||||
|
homeserverToUse=HomeServer, **kargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Setup a homeserver suitable for running tests against. Keyword arguments
|
Setup a homeserver suitable for running tests against. Keyword arguments
|
||||||
|
@ -137,6 +138,7 @@ def setup_test_homeserver(
|
||||||
config.limit_usage_by_mau = False
|
config.limit_usage_by_mau = False
|
||||||
config.hs_disabled = False
|
config.hs_disabled = False
|
||||||
config.hs_disabled_message = ""
|
config.hs_disabled_message = ""
|
||||||
|
config.hs_disabled_limit_type = ""
|
||||||
config.max_mau_value = 50
|
config.max_mau_value = 50
|
||||||
config.mau_limits_reserved_threepids = []
|
config.mau_limits_reserved_threepids = []
|
||||||
config.admin_uri = None
|
config.admin_uri = None
|
||||||
|
@ -191,7 +193,7 @@ def setup_test_homeserver(
|
||||||
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
||||||
|
|
||||||
if datastore is None:
|
if datastore is None:
|
||||||
hs = HomeServer(
|
hs = homeserverToUse(
|
||||||
name,
|
name,
|
||||||
config=config,
|
config=config,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
|
@ -234,7 +236,7 @@ def setup_test_homeserver(
|
||||||
|
|
||||||
hs.setup()
|
hs.setup()
|
||||||
else:
|
else:
|
||||||
hs = HomeServer(
|
hs = homeserverToUse(
|
||||||
name,
|
name,
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
datastore=datastore,
|
datastore=datastore,
|
||||||
|
|
Loading…
Reference in a new issue