Mirror of https://github.com/element-hq/synapse.git

commit d7dbc56c71
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

22 changed files with 423 additions and 123 deletions

@@ -146,6 +146,7 @@ To install the synapse homeserver run::

     virtualenv -p python2.7 ~/.synapse
     source ~/.synapse/bin/activate
     pip install --upgrade pip
+    pip install --upgrade setuptools
     pip install https://github.com/matrix-org/synapse/tarball/master

@@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::

     New user localpart: erikj
     Password:
     Confirm password:
+    Make admin [no]:
     Success!

 This process uses a setting ``registration_shared_secret`` in

@@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with
 This technique works equally for external functions which return deferreds,
 or deferreds we have made ourselves.
 
-XXX: think this is what ``preserve_context_over_deferred`` is supposed to do,
-though it is broken, in that it only restores the logcontext for the duration
-of the callbacks, which doesn't comply with the logcontext rules.
+You can also use ``logcontext.make_deferred_yieldable``, which just does the
+boilerplate for you, so the above could be written:
+
+.. code:: python
+
+    def sleep(seconds):
+        return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
+
 
 Fire-and-forget
 ---------------

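For readers following along, a usage sketch (a hypothetical caller, not part of this commit) of the rewritten ``sleep`` helper from the documentation above:

.. code:: python

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_request(request_id):
        # make_deferred_yieldable resets the logcontext while we are
        # suspended on the deferred, and restores it before this function
        # resumes, so log lines after the yield still carry request_id.
        yield sleep(10)
        defer.returnValue("handled %s" % (request_id,))
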
@@ -23,14 +23,27 @@ import signal
 import subprocess
 import sys
 import yaml
+import errno
+import time
 
 SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
 RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"
 
 
+def pid_running(pid):
+    try:
+        os.kill(pid, 0)
+        return True
+    except OSError, err:
+        if err.errno == errno.EPERM:
+            return True
+        return False
+
+
 def write(message, colour=NORMAL, stream=sys.stdout):
     if colour == NORMAL:
         stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
         stream.write(colour + message + NORMAL + "\n")
 
 
+def abort(message, colour=RED, stream=sys.stderr):
+    write(message, colour, stream)
+    sys.exit(1)
+
+
 def start(configfile):
     write("Starting ...")
     args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
 
     try:
         subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        write("started synapse.app.homeserver(%r)" %
+              (configfile,), colour=GREEN)
     except subprocess.CalledProcessError as e:
         write(
             "error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
 def stop(pidfile, app):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
-        os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        try:
+            os.kill(pid, signal.SIGTERM)
+            write("stopped %s" % (app,), colour=GREEN)
+        except OSError, err:
+            if err.errno == errno.ESRCH:
+                write("%s not running" % (app,), colour=YELLOW)
+            elif err.errno == errno.EPERM:
+                abort("Cannot stop %s: Operation not permitted" % (app,))
+            else:
+                abort("Cannot stop %s: Unknown error" % (app,))
 
 
 Worker = collections.namedtuple("Worker", [
@@ -190,7 +217,19 @@ def main():
     if start_stop_synapse:
         stop(pidfile, "synapse.app.homeserver")
 
-    # TODO: Wait for synapse to actually shutdown before starting it again
+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
+        if len(running_pids) > 0:
+            write("Waiting for process to exit before restarting...")
+            for running_pid in running_pids:
+                while pid_running(running_pid):
+                    time.sleep(0.2)
 
     if action == "start" or action == "restart":
         if start_stop_synapse:

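The new ``pid_running`` helper uses the classic Unix trick of sending signal 0, which performs the existence and permission checks without actually delivering a signal. A standalone sketch of the same idiom (written with Python 3 ``except`` syntax, unlike the Python 2 code in this commit):

.. code:: python

    import errno
    import os

    def pid_running(pid):
        # Signal 0 delivers nothing: the kernel only checks that the
        # target process exists and that we are allowed to signal it.
        try:
            os.kill(pid, 0)
            return True
        except OSError as err:
            # EPERM means the process exists but belongs to another user,
            # so it still counts as running; ESRCH means no such process.
            return err.errno == errno.EPERM
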
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -124,29 +125,23 @@ class ApplicationService(object):
                 raise ValueError(
                     "Expected bool for 'exclusive' in ns '%s'" % ns
                 )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                regex = regex_obj.get("regex")
+                if isinstance(regex, basestring):
+                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
+                else:
                     raise ValueError(
                         "Expected string for 'regex' in ns '%s'" % ns
                     )
         return namespaces
 
-    def _matches_regex(self, test_string, namespace_key, return_obj=False):
-        if not isinstance(test_string, basestring):
-            logger.error(
-                "Expected a string to test regex against, but got %s",
-                test_string
-            )
-            return False
-
+    def _matches_regex(self, test_string, namespace_key):
         for regex_obj in self.namespaces[namespace_key]:
-            if re.match(regex_obj["regex"], test_string):
-                if return_obj:
-                    return regex_obj
-                return True
-        return False
+            if regex_obj["regex"].match(test_string):
+                return regex_obj
+        return None
 
     def _is_exclusive(self, ns_key, test_string):
-        regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+        regex_obj = self._matches_regex(test_string, ns_key)
         if regex_obj:
             return regex_obj["exclusive"]
         return False
@@ -166,7 +161,14 @@ class ApplicationService(object):
         if not store:
             defer.returnValue(False)
 
-        member_list = yield store.get_users_in_room(event.room_id)
+        does_match = yield self._matches_user_in_member_list(event.room_id, store)
+        defer.returnValue(does_match)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _matches_user_in_member_list(self, room_id, store, cache_context):
+        member_list = yield store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
@@ -219,10 +221,10 @@ class ApplicationService(object):
         )
 
     def is_interested_in_alias(self, alias):
-        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
 
     def is_interested_in_room(self, room_id):
-        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
 
     def is_exclusive_user(self, user_id):
         return (

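Pre-compiling each namespace regex once at configuration time, instead of handing pattern strings to ``re.match`` on every event, is the point of the change above. A minimal self-contained sketch (hypothetical namespace data, not the synapse API) of the resulting hot path:

.. code:: python

    import re

    namespaces = [{"regex": "@irc_.*", "exclusive": True}]

    # compile once, up front
    for ns in namespaces:
        ns["regex"] = re.compile(ns["regex"])

    def first_match(test_string):
        # compiled patterns skip re's per-call cache lookup and parsing
        for ns in namespaces:
            if ns["regex"].match(test_string):
                return ns
        return None

    assert first_match("@irc_alice:example.com")["exclusive"]
    assert first_match("@bob:example.com") is None
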
@@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
 
         self.presence_map = {}
         self.presence_changed = sorteddict()
@@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
         else:
             self.edus[pos] = edu
 
+        self.notifier.on_new_replication_data()
+
     def send_presence(self, destination, states):
         """As per TransactionQueue"""
         pos = self._next_pos()
@@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
             (destination, state.user_id) for state in states
         ]
 
+        self.notifier.on_new_replication_data()
+
     def send_failure(self, failure, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
 
         self.failures[pos] = (destination, str(failure))
+        self.notifier.on_new_replication_data()
 
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
         self.device_messages[pos] = destination
+        self.notifier.on_new_replication_data()
 
     def get_current_token(self):
         return self.pos - 1

@@ -22,7 +22,7 @@ from .units import Transaction, Edu
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
 from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
@@ -303,12 +303,19 @@ class TransactionQueue(object):
             )
             return
 
+        pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
 
+            # This will throw if we wouldn't retry. We do this here so we fail
+            # quickly, but we will later check this again in the http client,
+            # hence why we throw the result away.
+            yield get_retry_limiter(destination, self.clock, self.store)
+
             # XXX: what's this for?
             yield run_on_reactor()
 
-            pending_pdus = []
             while True:
                 device_message_edus, device_stream_id, dev_list_id = (
                     yield self._get_new_device_messages(destination)
@@ -397,7 +404,7 @@ class TransactionQueue(object):
                 destination,
                 e,
             )
-            for p in pending_pdus:
+            for p, _ in pending_pdus:
                 logger.info("Failed to send event %s to %s", p.event_id,
                             destination)
         finally:

@@ -17,6 +17,7 @@ import logging
 import re
 
 from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
         return self._value_cache.get(dotted_key, None)
 
 
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
 def _glob_matches(glob, value, word_boundary=False):
     """Tests if value matches glob.
 
@@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
     Returns:
         bool
     """
 
     try:
-        if IS_GLOB.search(glob):
-            r = re.escape(glob)
-
-            r = r.replace(r'\*', '.*?')
-            r = r.replace(r'\?', '.')
-
-            # handle [abc], [a-z] and [!a-z] style ranges.
-            r = GLOB_REGEX.sub(
-                lambda x: (
-                    '[%s%s]' % (
-                        x.group(1) and '^' or '',
-                        x.group(2).replace(r'\\\-', '-')
-                    )
-                ),
-                r,
-            )
-            if word_boundary:
-                r = r"\b%s\b" % (r,)
-                r = _compile_regex(r)
-
-                return r.search(value)
-            else:
-                r = r + "$"
-                r = _compile_regex(r)
-
-                return r.match(value)
-        elif word_boundary:
-            r = re.escape(glob)
-            r = r"\b%s\b" % (r,)
-            r = _compile_regex(r)
-
-            return r.search(value)
-        else:
-            return value.lower() == glob.lower()
+        r = regex_cache.get((glob, word_boundary), None)
+        if not r:
+            r = _glob_to_re(glob, word_boundary)
+            regex_cache[(glob, word_boundary)] = r
+        return r.search(value)
     except re.error:
         logger.warn("Failed to parse glob to regex: %r", glob)
         return False
+
+
+def _glob_to_re(glob, word_boundary):
+    """Generates regex for a given glob.
+
+    Args:
+        glob (string)
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
+
+    Returns:
+        regex object
+    """
+    if IS_GLOB.search(glob):
+        r = re.escape(glob)
+
+        r = r.replace(r'\*', '.*?')
+        r = r.replace(r'\?', '.')
+
+        # handle [abc], [a-z] and [!a-z] style ranges.
+        r = GLOB_REGEX.sub(
+            lambda x: (
+                '[%s%s]' % (
+                    x.group(1) and '^' or '',
+                    x.group(2).replace(r'\\\-', '-')
+                )
+            ),
+            r,
+        )
+        if word_boundary:
+            r = r"\b%s\b" % (r,)
+
+            return re.compile(r, flags=re.IGNORECASE)
+        else:
+            r = "^" + r + "$"
+
+            return re.compile(r, flags=re.IGNORECASE)
+    elif word_boundary:
+        r = re.escape(glob)
+        r = r"\b%s\b" % (r,)
+
+        return re.compile(r, flags=re.IGNORECASE)
+    else:
+        r = "^" + re.escape(glob) + "$"
+        return re.compile(r, flags=re.IGNORECASE)
 
 
 def _flatten_dict(d, prefix=[], result={}):
     for key, value in d.items():
         if isinstance(value, basestring):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)
 
     return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
-    if r:
-        return r
-
-    r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
-    return r

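As a worked example of the glob-to-regex conversion that ``_glob_to_re`` implements, here is a sketch reproducing its escape-and-substitute steps (without the ``[abc]`` character-class handling):

.. code:: python

    import re

    def glob_to_re(glob, word_boundary=False):
        # escape everything, then re-introduce the glob wildcards:
        # '*' becomes a non-greedy '.*?', '?' becomes '.'
        r = re.escape(glob).replace(r'\*', '.*?').replace(r'\?', '.')
        if word_boundary:
            r = r"\b%s\b" % (r,)   # match anywhere, on word boundaries
        else:
            r = "^" + r + "$"      # otherwise anchor to the whole string
        return re.compile(r, flags=re.IGNORECASE)

    assert glob_to_re("m.room.*").match("m.room.message")
    assert not glob_to_re("m.room.*").match("xm.room.message")
    assert glob_to_re("cake*lie", word_boundary=True).search("the cake is a lie")
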
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",

@@ -195,11 +195,11 @@ class StateHandler(object):
                 (s.type, s.state_key): s.event_id for s in old_state
             }
             if event.is_state():
-                context.current_state_events = dict(context.prev_state_ids)
+                context.current_state_ids = dict(context.prev_state_ids)
                 key = (event.type, event.state_key)
-                context.current_state_events[key] = event.event_id
+                context.current_state_ids[key] = event.event_id
             else:
-                context.current_state_events = context.prev_state_ids
+                context.current_state_ids = context.prev_state_ids
         else:
             context.current_state_ids = {}
             context.prev_state_ids = {}

@@ -329,6 +329,7 @@ class DeviceStore(SQLBaseStore):
             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
             GROUP BY user_id, device_id
+            LIMIT 20
         """
         txn.execute(
             sql, (destination, from_stream_id, now_stream_id, False)
@@ -339,6 +340,9 @@ class DeviceStore(SQLBaseStore):
         if not query_map:
             return (now_stream_id, [])
 
+        if len(query_map) >= 20:
+            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
         devices = self._get_e2e_device_keys_txn(
             txn, query_map.keys(), include_all_devices=True
         )

@@ -14,6 +14,8 @@
 # limitations under the License.
+from twisted.internet import defer
+
 from synapse.api.errors import SynapseError
 
 from canonicaljson import encode_canonical_json
 import ujson as json
 
@@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
 
         return result
 
+    @defer.inlineCallbacks
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+        """Insert some new one time keys for a device.
+
+        Checks if any of the keys are already inserted, if they are then check
+        if they match. If they don't then we raise an error.
+        """
+
+        # First we check if we have already persisted any of the keys.
+        rows = yield self._simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=[key_id for _, key_id, _ in key_list],
+            retcols=("algorithm", "key_id", "key_json",),
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            desc="add_e2e_one_time_keys_check",
+        )
+
+        existing_key_map = {
+            (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+        }
+
+        new_keys = []  # Keys that we need to insert
+        for algorithm, key_id, json_bytes in key_list:
+            ex_bytes = existing_key_map.get((algorithm, key_id), None)
+            if ex_bytes:
+                if json_bytes != ex_bytes:
+                    raise SynapseError(
+                        400, "One time key with key_id %r already exists" % (key_id,)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, json_bytes))
+
         def _add_e2e_one_time_keys(txn):
-            for (algorithm, key_id, json_bytes) in key_list:
-                self._simple_upsert_txn(
-                    txn, table="e2e_one_time_keys_json",
-                    keyvalues={
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self._simple_insert_many_txn(
+                txn, table="e2e_one_time_keys_json",
+                values=[
+                    {
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
-                    },
-                    values={
                         "ts_added_ms": time_now,
                         "key_json": json_bytes,
                     }
-                )
-        return self.runInteraction(
-            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+        yield self.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
     def count_e2e_one_time_keys(self, user_id, device_id):

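The rewritten ``add_e2e_one_time_keys`` replaces a blind upsert with a check-then-insert: read the existing rows, reject mismatches, and insert only genuinely new keys, relying on a unique constraint to break ties between racing callers. A schematic sketch of that pattern (hypothetical helpers, not the synapse storage API):

.. code:: python

    def add_keys(existing, key_list, insert_many):
        """existing maps (algorithm, key_id) -> stored json; insert_many
        performs a batch insert protected by a unique constraint."""
        new_keys = []
        for algorithm, key_id, json_bytes in key_list:
            stored = existing.get((algorithm, key_id))
            if stored is None:
                new_keys.append((algorithm, key_id, json_bytes))
            elif stored != json_bytes:
                # same key id but different key material: refuse loudly
                raise ValueError("one time key %r already exists" % (key_id,))
            # an identical re-upload needs no action
        insert_many(new_keys)  # a racing duplicate surfaces as a constraint error
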
@@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
         txn.execute(sql, (room_id, ))
 
         results = []
-        for event_id, depth in txn:
+        for event_id, depth in txn.fetchall():
             hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()

@@ -496,7 +496,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=10000)
+    @cached(num_args=2, max_entries=100000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",

@@ -216,9 +216,7 @@ class StreamToken(
         return self
 
     def copy_and_replace(self, key, new_value):
-        d = self._asdict()
-        d[key] = new_value
-        return StreamToken(**d)
+        return self._replace(**{key: new_value})
 
 
 StreamToken.START = StreamToken(

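``_replace`` is the ``collections.namedtuple`` method for building a modified copy directly, avoiding the ``_asdict``-mutate-reconstruct dance it supersedes here. A quick illustration with a toy token type:

.. code:: python

    from collections import namedtuple

    Token = namedtuple("Token", ["events", "presence", "typing"])

    t1 = Token(events=10, presence=3, typing=7)
    t2 = t1._replace(**{"presence": 4})  # same as t1._replace(presence=4)

    assert t2 == Token(10, 4, 7)
    assert t1.presence == 3  # tuples are immutable; t1 is untouched
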
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
             deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers

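Because ``observe`` may now hand back a plain value instead of a ``Deferred``, callers that need uniform behaviour have to normalise the result, as the new docstring warns. The standard Twisted tool for that is ``defer.maybeDeferred``; a small sketch:

.. code:: python

    from twisted.internet import defer

    def observe_as_deferred(observable):
        # maybeDeferred calls the function and wraps a plain return value
        # in an already-fired Deferred; a returned Deferred passes through.
        return defer.maybeDeferred(observable.observe)
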
@@ -15,12 +15,9 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
 
 from . import DEBUG_CACHES, register_cache
 
@@ -227,8 +224,20 @@ class _CacheDescriptorBase(object):
         )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
 
-                return preserve_context_over_deferred(ret.observe())
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
     def __init__(self, orig, cached_method_name, list_name, num_args=None,
@@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(

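The ``get_cache_key`` generator added above replaces a per-call ``inspect.getcallargs`` with a cheap walk over positional args, keyword args, and the recorded defaults. A self-contained sketch of the same idea, outside the descriptor machinery:

.. code:: python

    def make_key_builder(arg_names, arg_defaults):
        """Build a cache-key generator for a function whose key args are
        arg_names; arg_defaults maps argument names to default values."""
        def get_cache_key(args, kwargs):
            pos = 0
            for name in arg_names:
                if name in kwargs:
                    yield kwargs[name]
                elif pos < len(args):
                    yield args[pos]
                    pos += 1
                else:
                    yield arg_defaults[name]
        return get_cache_key

    build = make_key_builder(["arg1", "arg2", "arg3"], {"arg2": 2, "arg3": 3})

    # fn(1, 2, 3), fn(1, 2) and fn(1, arg2=2) all map to the same cache key:
    assert tuple(build((1, 2, 3), {})) == (1, 2, 3)
    assert tuple(build((1, 2), {})) == (1, 2, 3)
    assert tuple(build((1,), {"arg2": 2})) == (1, 2, 3)
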
@@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -359,6 +363,25 @@ def preserve_fn(f):
     return g
 
 
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)
+
+
 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",

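``preserve_fn`` and ``make_deferred_yieldable`` form a matched pair for fanning work out, which is how other files in this commit use them. A sketch of the idiom (``fetch_one`` is a hypothetical stand-in for any deferred-returning call):

.. code:: python

    from twisted.internet import defer
    from synapse.util.logcontext import make_deferred_yieldable, preserve_fn

    def fetch_one(key):
        # stand-in for any deferred-returning storage call
        return defer.succeed(key * 2)

    @defer.inlineCallbacks
    def fetch_all(keys):
        # preserve_fn starts each call in the current logcontext and resets
        # it afterwards; make_deferred_yieldable makes the gathered deferred
        # safe to yield on under the logcontext rules.
        results = yield make_deferred_yieldable(defer.gatherResults(
            [preserve_fn(fetch_one)(key) for key in keys],
            consumeErrors=True,
        ))
        defer.returnValue(results)
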
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)

@@ -19,10 +19,12 @@ from twisted.internet import defer
 from mock import Mock
 from tests import unittest
 
+import re
+
 
 def _regex(regex, exclusive=True):
     return {
-        "regex": regex,
+        "regex": re.compile(regex),
         "exclusive": exclusive
     }
 

@@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo"), d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks

@@ -12,11 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 import mock
+from synapse.api.errors import SynapseError
+from synapse.util import async
+from synapse.util import logcontext
 from twisted.internet import defer
 from synapse.util.caches import descriptors
 from tests import unittest
 
+logger = logging.getLogger(__name__)
+
 
 class DescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
@@ -84,3 +91,125 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 5)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+    def test_cache_logcontexts(self):
+        """Check that logcontexts are set and restored correctly when
+        using the cache."""
+
+        complete_lookup = defer.Deferred()
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    with logcontext.PreserveLoggingContext():
+                        yield complete_lookup
+                    defer.returnValue(1)
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                r = yield obj.fn(1)
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+            defer.returnValue(r)
+
+        def check_result(r):
+            self.assertEqual(r, 1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d1.addCallback(check_result)
+
+        # and another
+        d2 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d2.addCallback(check_result)
+
+        # let the lookup complete
+        complete_lookup.callback(None)
+
+        return defer.gatherResults([d1, d2])
+
+    def test_cache_logcontexts_with_exception(self):
+        """Check that the cache sets and restores logcontexts correctly when
+        the lookup function throws an exception"""
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    yield async.run_on_reactor()
+                    raise SynapseError(400, "blah")
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                try:
+                    yield obj.fn(1)
+                    self.fail("No exception thrown")
+                except SynapseError:
+                    pass
+
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+
+        return d1
+
+    @defer.inlineCallbacks
+    def test_cache_default_args(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2=2, arg3=3):
+                return self.mock(arg1, arg2, arg3)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with same params shouldn't call the mock again
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_not_called()
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()

@@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
         # before the cache expires returns a resolved deferred.
         get_result_at_11 = self.cache.get(11, "key")
         self.assertIsNotNone(get_result_at_11)
-        self.assertTrue(get_result_at_11.called)
+        if isinstance(get_result_at_11, Deferred):
+            # The cache may return the actual result rather than a deferred
+            self.assertTrue(get_result_at_11.called)
 
         # Check that getting the key after the deferred has resolved
         # after the cache expires returns None