Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-03-31 10:11:07 +01:00
commit d7dbc56c71
22 changed files with 423 additions and 123 deletions

View file

@ -146,6 +146,7 @@ To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate source ~/.synapse/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools pip install --upgrade setuptools
pip install https://github.com/matrix-org/synapse/tarball/master pip install https://github.com/matrix-org/synapse/tarball/master
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
New user localpart: erikj New user localpart: erikj
Password: Password:
Confirm password: Confirm password:
Make admin [no]:
Success! Success!
This process uses a setting ``registration_shared_secret`` in This process uses a setting ``registration_shared_secret`` in

View file

@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with
This technique works equally for external functions which return deferreds, This technique works equally for external functions which return deferreds,
or deferreds we have made ourselves. or deferreds we have made ourselves.
XXX: think this is what ``preserve_context_over_deferred`` is supposed to do, You can also use ``logcontext.make_deferred_yieldable``, which just does the
though it is broken, in that it only restores the logcontext for the duration boilerplate for you, so the above could be written:
of the callbacks, which doesn't comply with the logcontext rules.
.. code:: python
def sleep(seconds):
return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
Fire-and-forget Fire-and-forget
--------------- ---------------

View file

@ -23,14 +23,27 @@ import signal
import subprocess import subprocess
import sys import sys
import yaml import yaml
import errno
import time
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
YELLOW = "\x1b[1;33m"
RED = "\x1b[1;31m" RED = "\x1b[1;31m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
def pid_running(pid):
try:
os.kill(pid, 0)
return True
except OSError, err:
if err.errno == errno.EPERM:
return True
return False
def write(message, colour=NORMAL, stream=sys.stdout): def write(message, colour=NORMAL, stream=sys.stdout):
if colour == NORMAL: if colour == NORMAL:
stream.write(message + "\n") stream.write(message + "\n")
@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n") stream.write(colour + message + NORMAL + "\n")
def abort(message, colour=RED, stream=sys.stderr):
write(message, colour, stream)
sys.exit(1)
def start(configfile): def start(configfile):
write("Starting ...") write("Starting ...")
args = SYNAPSE args = SYNAPSE
@ -45,7 +63,8 @@ def start(configfile):
try: try:
subprocess.check_call(args) subprocess.check_call(args)
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) write("started synapse.app.homeserver(%r)" %
(configfile,), colour=GREEN)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
write( write(
"error starting (exit code: %d); see above for logs" % e.returncode, "error starting (exit code: %d); see above for logs" % e.returncode,
@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
def stop(pidfile, app): def stop(pidfile, app):
if os.path.exists(pidfile): if os.path.exists(pidfile):
pid = int(open(pidfile).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) try:
write("stopped %s" % (app,), colour=GREEN) os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN)
except OSError, err:
if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW)
elif err.errno == errno.EPERM:
abort("Cannot stop %s: Operation not permitted" % (app,))
else:
abort("Cannot stop %s: Unknown error" % (app,))
Worker = collections.namedtuple("Worker", [ Worker = collections.namedtuple("Worker", [
@ -190,7 +217,19 @@ def main():
if start_stop_synapse: if start_stop_synapse:
stop(pidfile, "synapse.app.homeserver") stop(pidfile, "synapse.app.homeserver")
# TODO: Wait for synapse to actually shutdown before starting it again # Wait for synapse to actually shutdown before starting it again
if action == "restart":
running_pids = []
if start_stop_synapse and os.path.exists(pidfile):
running_pids.append(int(open(pidfile).read()))
for worker in workers:
if os.path.exists(worker.pidfile):
running_pids.append(int(open(worker.pidfile).read()))
if len(running_pids) > 0:
write("Waiting for process to exit before restarting...")
for running_pid in running_pids:
while pid_running(running_pid):
time.sleep(0.2)
if action == "start" or action == "restart": if action == "start" or action == "restart":
if start_stop_synapse: if start_stop_synapse:

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -124,29 +125,23 @@ class ApplicationService(object):
raise ValueError( raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns "Expected bool for 'exclusive' in ns '%s'" % ns
) )
if not isinstance(regex_obj.get("regex"), basestring): regex = regex_obj.get("regex")
if isinstance(regex, basestring):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError( raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns "Expected string for 'regex' in ns '%s'" % ns
) )
return namespaces return namespaces
def _matches_regex(self, test_string, namespace_key, return_obj=False): def _matches_regex(self, test_string, namespace_key):
if not isinstance(test_string, basestring):
logger.error(
"Expected a string to test regex against, but got %s",
test_string
)
return False
for regex_obj in self.namespaces[namespace_key]: for regex_obj in self.namespaces[namespace_key]:
if re.match(regex_obj["regex"], test_string): if regex_obj["regex"].match(test_string):
if return_obj: return regex_obj
return regex_obj return None
return True
return False
def _is_exclusive(self, ns_key, test_string): def _is_exclusive(self, ns_key, test_string):
regex_obj = self._matches_regex(test_string, ns_key, return_obj=True) regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj: if regex_obj:
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
@ -166,7 +161,14 @@ class ApplicationService(object):
if not store: if not store:
defer.returnValue(False) defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id) does_match = yield self._matches_user_in_member_list(event.room_id, store)
defer.returnValue(does_match)
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context):
member_list = yield store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
# check joined member events # check joined member events
for user_id in member_list: for user_id in member_list:
@ -219,10 +221,10 @@ class ApplicationService(object):
) )
def is_interested_in_alias(self, alias): def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES) return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
def is_interested_in_room(self, room_id): def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS) return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
def is_exclusive_user(self, user_id): def is_exclusive_user(self, user_id):
return ( return (

View file

@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
def __init__(self, hs): def __init__(self, hs):
self.server_name = hs.hostname self.server_name = hs.hostname
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.presence_map = {} self.presence_map = {}
self.presence_changed = sorteddict() self.presence_changed = sorteddict()
@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
else: else:
self.edus[pos] = edu self.edus[pos] = edu
self.notifier.on_new_replication_data()
def send_presence(self, destination, states): def send_presence(self, destination, states):
"""As per TransactionQueue""" """As per TransactionQueue"""
pos = self._next_pos() pos = self._next_pos()
@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
(destination, state.user_id) for state in states (destination, state.user_id) for state in states
] ]
self.notifier.on_new_replication_data()
def send_failure(self, failure, destination): def send_failure(self, failure, destination):
"""As per TransactionQueue""" """As per TransactionQueue"""
pos = self._next_pos() pos = self._next_pos()
self.failures[pos] = (destination, str(failure)) self.failures[pos] = (destination, str(failure))
self.notifier.on_new_replication_data()
def send_device_messages(self, destination): def send_device_messages(self, destination):
"""As per TransactionQueue""" """As per TransactionQueue"""
pos = self._next_pos() pos = self._next_pos()
self.device_messages[pos] = destination self.device_messages[pos] = destination
self.notifier.on_new_replication_data()
def get_current_token(self): def get_current_token(self):
return self.pos - 1 return self.pos - 1

View file

@ -22,7 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -303,12 +303,19 @@ class TransactionQueue(object):
) )
return return
pending_pdus = []
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
# This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client,
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
# XXX: what's this for? # XXX: what's this for?
yield run_on_reactor() yield run_on_reactor()
pending_pdus = []
while True: while True:
device_message_edus, device_stream_id, dev_list_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
@ -397,7 +404,7 @@ class TransactionQueue(object):
destination, destination,
e, e,
) )
for p in pending_pdus: for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id, logger.info("Failed to send event %s to %s", p.event_id,
destination) destination)
finally: finally:

View file

@ -17,6 +17,7 @@ import logging
import re import re
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
return self._value_cache.get(dotted_key, None) return self._value_cache.get(dotted_key, None)
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
register_cache("regex_push_cache", regex_cache)
def _glob_matches(glob, value, word_boundary=False): def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob. """Tests if value matches glob.
@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
Returns: Returns:
bool bool
""" """
try: try:
if IS_GLOB.search(glob): r = regex_cache.get((glob, word_boundary), None)
r = re.escape(glob) if not r:
r = _glob_to_re(glob, word_boundary)
r = r.replace(r'\*', '.*?') regex_cache[(glob, word_boundary)] = r
r = r.replace(r'\?', '.') return r.search(value)
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
if word_boundary:
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
r = r + "$"
r = _compile_regex(r)
return r.match(value)
elif word_boundary:
r = re.escape(glob)
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
return value.lower() == glob.lower()
except re.error: except re.error:
logger.warn("Failed to parse glob to regex: %r", glob) logger.warn("Failed to parse glob to regex: %r", glob)
return False return False
def _glob_to_re(glob, word_boundary):
"""Generates regex for a given glob.
Args:
glob (string)
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
regex object
"""
if IS_GLOB.search(glob):
r = re.escape(glob)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
if word_boundary:
r = r"\b%s\b" % (r,)
return re.compile(r, flags=re.IGNORECASE)
else:
r = "^" + r + "$"
return re.compile(r, flags=re.IGNORECASE)
elif word_boundary:
r = re.escape(glob)
r = r"\b%s\b" % (r,)
return re.compile(r, flags=re.IGNORECASE)
else:
r = "^" + re.escape(glob) + "$"
return re.compile(r, flags=re.IGNORECASE)
def _flatten_dict(d, prefix=[], result={}): def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items(): for key, value in d.items():
if isinstance(value, basestring): if isinstance(value, basestring):
@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
_flatten_dict(value, prefix=(prefix + [key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result return result
regex_cache = LruCache(5000)
def _compile_regex(regex_str):
r = regex_cache.get(regex_str, None)
if r:
return r
r = re.compile(regex_str, flags=re.IGNORECASE)
regex_cache[regex_str] = r
return r

View file

@ -17,15 +17,12 @@ from twisted.internet import defer
from synapse.push.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_badge_count(store, user_id): def get_badge_count(store, user_id):
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ invites = yield store.get_invited_rooms_for_user(user_id)
preserve_fn(store.get_invited_rooms_for_user)(user_id), joins = yield store.get_rooms_for_user(user_id)
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user( my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read", user_id, "m.read",

View file

@ -195,11 +195,11 @@ class StateHandler(object):
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state(): if event.is_state():
context.current_state_events = dict(context.prev_state_ids) context.current_state_ids = dict(context.prev_state_ids)
key = (event.type, event.state_key) key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id context.current_state_ids[key] = event.event_id
else: else:
context.current_state_events = context.prev_state_ids context.current_state_ids = context.prev_state_ids
else: else:
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {} context.prev_state_ids = {}

View file

@ -329,6 +329,7 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id GROUP BY user_id, device_id
LIMIT 20
""" """
txn.execute( txn.execute(
sql, (destination, from_stream_id, now_stream_id, False) sql, (destination, from_stream_id, now_stream_id, False)
@ -339,6 +340,9 @@ class DeviceStore(SQLBaseStore):
if not query_map: if not query_map:
return (now_stream_id, []) return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
) )

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import ujson as json import ujson as json
@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
return result return result
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
"""Insert some new one time keys for a device.
Checks if any of the keys are already inserted, if they are then check
if they match. If they don't then we raise an error.
"""
# First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=[key_id for _, key_id, _ in key_list],
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
desc="add_e2e_one_time_keys_check",
)
existing_key_map = {
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
}
new_keys = [] # Keys that we need to insert
for algorithm, key_id, json_bytes in key_list:
ex_bytes = existing_key_map.get((algorithm, key_id), None)
if ex_bytes:
if json_bytes != ex_bytes:
raise SynapseError(
400, "One time key with key_id %r already exists" % (key_id,)
)
else:
new_keys.append((algorithm, key_id, json_bytes))
def _add_e2e_one_time_keys(txn): def _add_e2e_one_time_keys(txn):
for (algorithm, key_id, json_bytes) in key_list: # We are protected from race between lookup and insertion due to
self._simple_upsert_txn( # a unique constraint. If there is a race of two calls to
txn, table="e2e_one_time_keys_json", # `add_e2e_one_time_keys` then they'll conflict and we will only
keyvalues={ # insert one set.
self._simple_insert_many_txn(
txn, table="e2e_one_time_keys_json",
values=[
{
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
"algorithm": algorithm, "algorithm": algorithm,
"key_id": key_id, "key_id": key_id,
},
values={
"ts_added_ms": time_now, "ts_added_ms": time_now,
"key_json": json_bytes, "key_json": json_bytes,
} }
) for algorithm, key_id, json_bytes in new_keys
return self.runInteraction( ],
"add_e2e_one_time_keys", _add_e2e_one_time_keys )
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
) )
def count_e2e_one_time_keys(self, user_id, device_id): def count_e2e_one_time_keys(self, user_id, device_id):

View file

@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, )) txn.execute(sql, (room_id, ))
results = [] results = []
for event_id, depth in txn: for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items()

View file

@ -496,7 +496,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types) state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(num_args=2, max_entries=10000) @cached(num_args=2, max_entries=100000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",

View file

@ -216,9 +216,7 @@ class StreamToken(
return self return self
def copy_and_replace(self, key, new_value): def copy_and_replace(self, key, new_value):
d = self._asdict() return self._replace(**{key: new_value})
d[key] = new_value
return StreamToken(**d)
StreamToken.START = StreamToken( StreamToken.START = StreamToken(

View file

@ -89,6 +89,11 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback) deferred.addCallbacks(callback, errback)
def observe(self): def observe(self):
"""Observe the underlying deferred.
Can return either a deferred if the underlying deferred is still pending
(or has failed), or the actual value. Callers may need to use maybeDeferred.
"""
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
@ -101,7 +106,7 @@ class ObservableDeferred(object):
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return res if success else defer.fail(res)
def observers(self): def observers(self):
return self._observers return self._observers

View file

@ -15,12 +15,9 @@
import logging import logging
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
from . import DEBUG_CACHES, register_cache from . import DEBUG_CACHES, register_cache
@ -227,8 +224,20 @@ class _CacheDescriptorBase(object):
) )
self.num_args = num_args self.num_args = num_args
# list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1] self.arg_names = all_args[1:num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
self.arg_defaults = dict(zip(
all_args[-len(arg_spec.defaults):],
arg_spec.defaults
))
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names: if "cache_context" in self.arg_names:
raise Exception( raise Exception(
"cache_context arg cannot be included among the cache keys" "cache_context arg cannot be included among the cache keys"
@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable, iterable=self.iterable,
) )
def get_cache_key(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode cache_key = tuple(get_cache_key(args, kwargs))
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function # Add our own `cache_context` to argument list if the wrapped function
# has asked for one # has asked for one
@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase):
defer.returnValue(cached_result) defer.returnValue(cached_result)
observer.addCallback(check_result) observer.addCallback(check_result)
return preserve_context_over_deferred(observer)
except KeyError: except KeyError:
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn, logcontext.preserve_fn(self.function_to_call),
self.function_to_call,
obj, *args, **kwargs obj, *args, **kwargs
) )
@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, ret, callback=invalidate_callback) cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()
return preserve_context_over_deferred(ret.observe()) if isinstance(observer, defer.Deferred):
return logcontext.make_deferred_yieldable(observer)
else:
return observer
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys. """Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion. the list of missing keys to the wrapped function.
Once wrapped, the function returns either a Deferred which resolves to
the list of results, or (if all results were cached), just the list of
results.
""" """
def __init__(self, orig, cached_method_name, list_name, num_args=None, def __init__(self, orig, cached_method_name, list_name, num_args=None,
@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred( ret_d = defer.maybeDeferred(
preserve_context_over_fn, logcontext.preserve_fn(self.function_to_call),
self.function_to_call,
**args_to_call **args_to_call
) )
@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
# We need to create deferreds for each arg in the list so that # We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache. # we can insert the new deferred into the cache.
for arg in missing: for arg in missing:
with PreserveLoggingContext(): observer = ret_d.observe()
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer) observer = ObservableDeferred(observer)
@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
results.update(res) results.update(res)
return results return results
return preserve_context_over_deferred(defer.gatherResults( return logcontext.make_deferred_yieldable(defer.gatherResults(
cached_defers.values(), cached_defers.values(),
consumeErrors=True, consumeErrors=True,
).addCallback(update_results_dict).addErrback( ).addCallback(update_results_dict).addErrback(

View file

@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
def preserve_context_over_deferred(deferred, context=None): def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will """Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context. be invoked with the current context.
Deprecated: this almost certainly doesn't do want you want, ie make
the deferred follow the synapse logcontext rules: try
``make_deferred_yieldable`` instead.
""" """
if context is None: if context is None:
context = LoggingContext.current_context() context = LoggingContext.current_context()
@ -359,6 +363,25 @@ def preserve_fn(f):
return g return g
@defer.inlineCallbacks
def make_deferred_yieldable(deferred):
"""Given a deferred, make it follow the Synapse logcontext rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
result/failure).
If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
current logcontext before running callbacks/errbacks.
(This is more-or-less the opposite operation to preserve_fn.)
"""
with PreserveLoggingContext():
r = yield deferred
defer.returnValue(r)
# modules to ignore in `logcontext_tracer` # modules to ignore in `logcontext_tracer`
_to_ignore = [ _to_ignore = [
"synapse.util.logcontext", "synapse.util.logcontext",

View file

@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
events ([synapse.events.EventBase]): list of events to filter events ([synapse.events.EventBase]): list of events to filter
""" """
forgotten = yield preserve_context_over_deferred(defer.gatherResults([ forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)( defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
room_id, room_id,
) )
for room_id in frozenset(e.room_id for e in events) for room_id in frozenset(e.room_id for e in events)

View file

@ -19,10 +19,12 @@ from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
import re
def _regex(regex, exclusive=True): def _regex(regex, exclusive=True):
return { return {
"regex": regex, "regex": re.compile(regex),
"exclusive": exclusive "exclusive": exclusive
} }

View file

@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d)) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo"), d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -12,11 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import mock import mock
from synapse.api.errors import SynapseError
from synapse.util import async
from synapse.util import logcontext
from twisted.internet import defer from twisted.internet import defer
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from tests import unittest from tests import unittest
logger = logging.getLogger(__name__)
class DescriptorTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
@ -84,3 +91,125 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 5) r = yield obj.fn(2, 5)
self.assertEqual(r, 'chips') self.assertEqual(r, 'chips')
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
complete_lookup = defer.Deferred()
class Cls(object):
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
with logcontext.PreserveLoggingContext():
yield complete_lookup
defer.returnValue(1)
return inner_fn()
@defer.inlineCallbacks
def do_lookup():
with logcontext.LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
self.assertEqual(logcontext.LoggingContext.current_context(),
c1)
defer.returnValue(r)
def check_result(r):
self.assertEqual(r, 1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
d2.addCallback(check_result)
# let the lookup complete
complete_lookup.callback(None)
return defer.gatherResults([d1, d2])
def test_cache_logcontexts_with_exception(self):
"""Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception"""
class Cls(object):
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
yield async.run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@defer.inlineCallbacks
def do_lookup():
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
yield obj.fn(1)
self.fail("No exception thrown")
except SynapseError:
pass
self.assertEqual(logcontext.LoggingContext.current_context(),
c1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
return d1
@defer.inlineCallbacks
def test_cache_default_args(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2=2, arg3=3):
return self.mock(arg1, arg2, arg3)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()
# a call with same params shouldn't call the mock again
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_not_called()
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()

View file

@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
# before the cache expires returns a resolved deferred. # before the cache expires returns a resolved deferred.
get_result_at_11 = self.cache.get(11, "key") get_result_at_11 = self.cache.get(11, "key")
self.assertIsNotNone(get_result_at_11) self.assertIsNotNone(get_result_at_11)
self.assertTrue(get_result_at_11.called) if isinstance(get_result_at_11, Deferred):
# The cache may return the actual result rather than a deferred
self.assertTrue(get_result_at_11.called)
# Check that getting the key after the deferred has resolved # Check that getting the key after the deferred has resolved
# after the cache expires returns None # after the cache expires returns None