mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-18 17:10:43 +03:00
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
This commit is contained in:
commit
d7dbc56c71
22 changed files with 423 additions and 123 deletions
|
@ -146,6 +146,7 @@ To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
source ~/.synapse/bin/activate
|
source ~/.synapse/bin/activate
|
||||||
|
pip install --upgrade pip
|
||||||
pip install --upgrade setuptools
|
pip install --upgrade setuptools
|
||||||
pip install https://github.com/matrix-org/synapse/tarball/master
|
pip install https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
|
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
|
||||||
New user localpart: erikj
|
New user localpart: erikj
|
||||||
Password:
|
Password:
|
||||||
Confirm password:
|
Confirm password:
|
||||||
|
Make admin [no]:
|
||||||
Success!
|
Success!
|
||||||
|
|
||||||
This process uses a setting ``registration_shared_secret`` in
|
This process uses a setting ``registration_shared_secret`` in
|
||||||
|
|
|
@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with
|
||||||
This technique works equally for external functions which return deferreds,
|
This technique works equally for external functions which return deferreds,
|
||||||
or deferreds we have made ourselves.
|
or deferreds we have made ourselves.
|
||||||
|
|
||||||
XXX: think this is what ``preserve_context_over_deferred`` is supposed to do,
|
You can also use ``logcontext.make_deferred_yieldable``, which just does the
|
||||||
though it is broken, in that it only restores the logcontext for the duration
|
boilerplate for you, so the above could be written:
|
||||||
of the callbacks, which doesn't comply with the logcontext rules.
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def sleep(seconds):
|
||||||
|
return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
|
||||||
|
|
||||||
|
|
||||||
Fire-and-forget
|
Fire-and-forget
|
||||||
---------------
|
---------------
|
||||||
|
|
|
@ -23,14 +23,27 @@ import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
import yaml
|
||||||
|
import errno
|
||||||
|
import time
|
||||||
|
|
||||||
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
||||||
|
|
||||||
GREEN = "\x1b[1;32m"
|
GREEN = "\x1b[1;32m"
|
||||||
|
YELLOW = "\x1b[1;33m"
|
||||||
RED = "\x1b[1;31m"
|
RED = "\x1b[1;31m"
|
||||||
NORMAL = "\x1b[m"
|
NORMAL = "\x1b[m"
|
||||||
|
|
||||||
|
|
||||||
|
def pid_running(pid):
|
||||||
|
try:
|
||||||
|
os.kill(pid, 0)
|
||||||
|
return True
|
||||||
|
except OSError, err:
|
||||||
|
if err.errno == errno.EPERM:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||||
if colour == NORMAL:
|
if colour == NORMAL:
|
||||||
stream.write(message + "\n")
|
stream.write(message + "\n")
|
||||||
|
@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
|
||||||
stream.write(colour + message + NORMAL + "\n")
|
stream.write(colour + message + NORMAL + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def abort(message, colour=RED, stream=sys.stderr):
|
||||||
|
write(message, colour, stream)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def start(configfile):
|
def start(configfile):
|
||||||
write("Starting ...")
|
write("Starting ...")
|
||||||
args = SYNAPSE
|
args = SYNAPSE
|
||||||
|
@ -45,7 +63,8 @@ def start(configfile):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(args)
|
subprocess.check_call(args)
|
||||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
write("started synapse.app.homeserver(%r)" %
|
||||||
|
(configfile,), colour=GREEN)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
write(
|
write(
|
||||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||||
|
@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
|
||||||
def stop(pidfile, app):
|
def stop(pidfile, app):
|
||||||
if os.path.exists(pidfile):
|
if os.path.exists(pidfile):
|
||||||
pid = int(open(pidfile).read())
|
pid = int(open(pidfile).read())
|
||||||
os.kill(pid, signal.SIGTERM)
|
try:
|
||||||
write("stopped %s" % (app,), colour=GREEN)
|
os.kill(pid, signal.SIGTERM)
|
||||||
|
write("stopped %s" % (app,), colour=GREEN)
|
||||||
|
except OSError, err:
|
||||||
|
if err.errno == errno.ESRCH:
|
||||||
|
write("%s not running" % (app,), colour=YELLOW)
|
||||||
|
elif err.errno == errno.EPERM:
|
||||||
|
abort("Cannot stop %s: Operation not permitted" % (app,))
|
||||||
|
else:
|
||||||
|
abort("Cannot stop %s: Unknown error" % (app,))
|
||||||
|
|
||||||
|
|
||||||
Worker = collections.namedtuple("Worker", [
|
Worker = collections.namedtuple("Worker", [
|
||||||
|
@ -190,7 +217,19 @@ def main():
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
stop(pidfile, "synapse.app.homeserver")
|
stop(pidfile, "synapse.app.homeserver")
|
||||||
|
|
||||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
# Wait for synapse to actually shutdown before starting it again
|
||||||
|
if action == "restart":
|
||||||
|
running_pids = []
|
||||||
|
if start_stop_synapse and os.path.exists(pidfile):
|
||||||
|
running_pids.append(int(open(pidfile).read()))
|
||||||
|
for worker in workers:
|
||||||
|
if os.path.exists(worker.pidfile):
|
||||||
|
running_pids.append(int(open(worker.pidfile).read()))
|
||||||
|
if len(running_pids) > 0:
|
||||||
|
write("Waiting for process to exit before restarting...")
|
||||||
|
for running_pid in running_pids:
|
||||||
|
while pid_running(running_pid):
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
if action == "start" or action == "restart":
|
if action == "start" or action == "restart":
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -124,29 +125,23 @@ class ApplicationService(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected bool for 'exclusive' in ns '%s'" % ns
|
"Expected bool for 'exclusive' in ns '%s'" % ns
|
||||||
)
|
)
|
||||||
if not isinstance(regex_obj.get("regex"), basestring):
|
regex = regex_obj.get("regex")
|
||||||
|
if isinstance(regex, basestring):
|
||||||
|
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected string for 'regex' in ns '%s'" % ns
|
"Expected string for 'regex' in ns '%s'" % ns
|
||||||
)
|
)
|
||||||
return namespaces
|
return namespaces
|
||||||
|
|
||||||
def _matches_regex(self, test_string, namespace_key, return_obj=False):
|
def _matches_regex(self, test_string, namespace_key):
|
||||||
if not isinstance(test_string, basestring):
|
|
||||||
logger.error(
|
|
||||||
"Expected a string to test regex against, but got %s",
|
|
||||||
test_string
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
for regex_obj in self.namespaces[namespace_key]:
|
for regex_obj in self.namespaces[namespace_key]:
|
||||||
if re.match(regex_obj["regex"], test_string):
|
if regex_obj["regex"].match(test_string):
|
||||||
if return_obj:
|
return regex_obj
|
||||||
return regex_obj
|
return None
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_exclusive(self, ns_key, test_string):
|
def _is_exclusive(self, ns_key, test_string):
|
||||||
regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
|
regex_obj = self._matches_regex(test_string, ns_key)
|
||||||
if regex_obj:
|
if regex_obj:
|
||||||
return regex_obj["exclusive"]
|
return regex_obj["exclusive"]
|
||||||
return False
|
return False
|
||||||
|
@ -166,7 +161,14 @@ class ApplicationService(object):
|
||||||
if not store:
|
if not store:
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
member_list = yield store.get_users_in_room(event.room_id)
|
does_match = yield self._matches_user_in_member_list(event.room_id, store)
|
||||||
|
defer.returnValue(does_match)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||||
|
def _matches_user_in_member_list(self, room_id, store, cache_context):
|
||||||
|
member_list = yield store.get_users_in_room(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
|
)
|
||||||
|
|
||||||
# check joined member events
|
# check joined member events
|
||||||
for user_id in member_list:
|
for user_id in member_list:
|
||||||
|
@ -219,10 +221,10 @@ class ApplicationService(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_interested_in_alias(self, alias):
|
def is_interested_in_alias(self, alias):
|
||||||
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
|
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
|
||||||
|
|
||||||
def is_interested_in_room(self, room_id):
|
def is_interested_in_room(self, room_id):
|
||||||
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
|
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
|
||||||
|
|
||||||
def is_exclusive_user(self, user_id):
|
def is_exclusive_user(self, user_id):
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
self.presence_map = {}
|
self.presence_map = {}
|
||||||
self.presence_changed = sorteddict()
|
self.presence_changed = sorteddict()
|
||||||
|
@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
|
||||||
else:
|
else:
|
||||||
self.edus[pos] = edu
|
self.edus[pos] = edu
|
||||||
|
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_presence(self, destination, states):
|
def send_presence(self, destination, states):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
|
@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
|
||||||
(destination, state.user_id) for state in states
|
(destination, state.user_id) for state in states
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
def send_failure(self, failure, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
|
|
||||||
self.failures[pos] = (destination, str(failure))
|
self.failures[pos] = (destination, str(failure))
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
self.device_messages[pos] = destination
|
self.device_messages[pos] = destination
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
return self.pos - 1
|
return self.pos - 1
|
||||||
|
|
|
@ -22,7 +22,7 @@ from .units import Transaction, Edu
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
|
@ -303,12 +303,19 @@ class TransactionQueue(object):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
pending_pdus = []
|
||||||
try:
|
try:
|
||||||
self.pending_transactions[destination] = 1
|
self.pending_transactions[destination] = 1
|
||||||
|
|
||||||
|
# This will throw if we wouldn't retry. We do this here so we fail
|
||||||
|
# quickly, but we will later check this again in the http client,
|
||||||
|
# hence why we throw the result away.
|
||||||
|
yield get_retry_limiter(destination, self.clock, self.store)
|
||||||
|
|
||||||
# XXX: what's this for?
|
# XXX: what's this for?
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
pending_pdus = []
|
||||||
while True:
|
while True:
|
||||||
device_message_edus, device_stream_id, dev_list_id = (
|
device_message_edus, device_stream_id, dev_list_id = (
|
||||||
yield self._get_new_device_messages(destination)
|
yield self._get_new_device_messages(destination)
|
||||||
|
@ -397,7 +404,7 @@ class TransactionQueue(object):
|
||||||
destination,
|
destination,
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
for p in pending_pdus:
|
for p, _ in pending_pdus:
|
||||||
logger.info("Failed to send event %s to %s", p.event_id,
|
logger.info("Failed to send event %s to %s", p.event_id,
|
||||||
destination)
|
destination)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
|
||||||
return self._value_cache.get(dotted_key, None)
|
return self._value_cache.get(dotted_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
|
||||||
|
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
|
||||||
|
register_cache("regex_push_cache", regex_cache)
|
||||||
|
|
||||||
|
|
||||||
def _glob_matches(glob, value, word_boundary=False):
|
def _glob_matches(glob, value, word_boundary=False):
|
||||||
"""Tests if value matches glob.
|
"""Tests if value matches glob.
|
||||||
|
|
||||||
|
@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
|
||||||
Returns:
|
Returns:
|
||||||
bool
|
bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if IS_GLOB.search(glob):
|
r = regex_cache.get((glob, word_boundary), None)
|
||||||
r = re.escape(glob)
|
if not r:
|
||||||
|
r = _glob_to_re(glob, word_boundary)
|
||||||
r = r.replace(r'\*', '.*?')
|
regex_cache[(glob, word_boundary)] = r
|
||||||
r = r.replace(r'\?', '.')
|
return r.search(value)
|
||||||
|
|
||||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
|
||||||
r = GLOB_REGEX.sub(
|
|
||||||
lambda x: (
|
|
||||||
'[%s%s]' % (
|
|
||||||
x.group(1) and '^' or '',
|
|
||||||
x.group(2).replace(r'\\\-', '-')
|
|
||||||
)
|
|
||||||
),
|
|
||||||
r,
|
|
||||||
)
|
|
||||||
if word_boundary:
|
|
||||||
r = r"\b%s\b" % (r,)
|
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.search(value)
|
|
||||||
else:
|
|
||||||
r = r + "$"
|
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.match(value)
|
|
||||||
elif word_boundary:
|
|
||||||
r = re.escape(glob)
|
|
||||||
r = r"\b%s\b" % (r,)
|
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.search(value)
|
|
||||||
else:
|
|
||||||
return value.lower() == glob.lower()
|
|
||||||
except re.error:
|
except re.error:
|
||||||
logger.warn("Failed to parse glob to regex: %r", glob)
|
logger.warn("Failed to parse glob to regex: %r", glob)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _glob_to_re(glob, word_boundary):
|
||||||
|
"""Generates regex for a given glob.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
glob (string)
|
||||||
|
word_boundary (bool): Whether to match against word boundaries or entire
|
||||||
|
string. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
regex object
|
||||||
|
"""
|
||||||
|
if IS_GLOB.search(glob):
|
||||||
|
r = re.escape(glob)
|
||||||
|
|
||||||
|
r = r.replace(r'\*', '.*?')
|
||||||
|
r = r.replace(r'\?', '.')
|
||||||
|
|
||||||
|
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||||
|
r = GLOB_REGEX.sub(
|
||||||
|
lambda x: (
|
||||||
|
'[%s%s]' % (
|
||||||
|
x.group(1) and '^' or '',
|
||||||
|
x.group(2).replace(r'\\\-', '-')
|
||||||
|
)
|
||||||
|
),
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
if word_boundary:
|
||||||
|
r = r"\b%s\b" % (r,)
|
||||||
|
|
||||||
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
|
else:
|
||||||
|
r = "^" + r + "$"
|
||||||
|
|
||||||
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
|
elif word_boundary:
|
||||||
|
r = re.escape(glob)
|
||||||
|
r = r"\b%s\b" % (r,)
|
||||||
|
|
||||||
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
|
else:
|
||||||
|
r = "^" + re.escape(glob) + "$"
|
||||||
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def _flatten_dict(d, prefix=[], result={}):
|
def _flatten_dict(d, prefix=[], result={}):
|
||||||
for key, value in d.items():
|
for key, value in d.items():
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
|
@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
|
||||||
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
regex_cache = LruCache(5000)
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_regex(regex_str):
|
|
||||||
r = regex_cache.get(regex_str, None)
|
|
||||||
if r:
|
|
||||||
return r
|
|
||||||
|
|
||||||
r = re.compile(regex_str, flags=re.IGNORECASE)
|
|
||||||
regex_cache[regex_str] = r
|
|
||||||
return r
|
|
||||||
|
|
|
@ -17,15 +17,12 @@ from twisted.internet import defer
|
||||||
from synapse.push.presentable_names import (
|
from synapse.push.presentable_names import (
|
||||||
calculate_room_name, name_from_member_event
|
calculate_room_name, name_from_member_event
|
||||||
)
|
)
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_badge_count(store, user_id):
|
def get_badge_count(store, user_id):
|
||||||
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
|
invites = yield store.get_invited_rooms_for_user(user_id)
|
||||||
preserve_fn(store.get_invited_rooms_for_user)(user_id),
|
joins = yield store.get_rooms_for_user(user_id)
|
||||||
preserve_fn(store.get_rooms_for_user)(user_id),
|
|
||||||
], consumeErrors=True))
|
|
||||||
|
|
||||||
my_receipts_by_room = yield store.get_receipts_for_user(
|
my_receipts_by_room = yield store.get_receipts_for_user(
|
||||||
user_id, "m.read",
|
user_id, "m.read",
|
||||||
|
|
|
@ -195,11 +195,11 @@ class StateHandler(object):
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
context.current_state_events = dict(context.prev_state_ids)
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
context.current_state_events[key] = event.event_id
|
context.current_state_ids[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
context.current_state_events = context.prev_state_ids
|
context.current_state_ids = context.prev_state_ids
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = {}
|
context.current_state_ids = {}
|
||||||
context.prev_state_ids = {}
|
context.prev_state_ids = {}
|
||||||
|
|
|
@ -329,6 +329,7 @@ class DeviceStore(SQLBaseStore):
|
||||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||||
GROUP BY user_id, device_id
|
GROUP BY user_id, device_id
|
||||||
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql, (destination, from_stream_id, now_stream_id, False)
|
sql, (destination, from_stream_id, now_stream_id, False)
|
||||||
|
@ -339,6 +340,9 @@ class DeviceStore(SQLBaseStore):
|
||||||
if not query_map:
|
if not query_map:
|
||||||
return (now_stream_id, [])
|
return (now_stream_id, [])
|
||||||
|
|
||||||
|
if len(query_map) >= 20:
|
||||||
|
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(
|
devices = self._get_e2e_device_keys_txn(
|
||||||
txn, query_map.keys(), include_all_devices=True
|
txn, query_map.keys(), include_all_devices=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
|
@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||||
|
"""Insert some new one time keys for a device.
|
||||||
|
|
||||||
|
Checks if any of the keys are already inserted, if they are then check
|
||||||
|
if they match. If they don't then we raise an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# First we check if we have already persisted any of the keys.
|
||||||
|
rows = yield self._simple_select_many_batch(
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
column="key_id",
|
||||||
|
iterable=[key_id for _, key_id, _ in key_list],
|
||||||
|
retcols=("algorithm", "key_id", "key_json",),
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
desc="add_e2e_one_time_keys_check",
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_key_map = {
|
||||||
|
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
new_keys = [] # Keys that we need to insert
|
||||||
|
for algorithm, key_id, json_bytes in key_list:
|
||||||
|
ex_bytes = existing_key_map.get((algorithm, key_id), None)
|
||||||
|
if ex_bytes:
|
||||||
|
if json_bytes != ex_bytes:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "One time key with key_id %r already exists" % (key_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_keys.append((algorithm, key_id, json_bytes))
|
||||||
|
|
||||||
def _add_e2e_one_time_keys(txn):
|
def _add_e2e_one_time_keys(txn):
|
||||||
for (algorithm, key_id, json_bytes) in key_list:
|
# We are protected from race between lookup and insertion due to
|
||||||
self._simple_upsert_txn(
|
# a unique constraint. If there is a race of two calls to
|
||||||
txn, table="e2e_one_time_keys_json",
|
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
||||||
keyvalues={
|
# insert one set.
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn, table="e2e_one_time_keys_json",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"algorithm": algorithm,
|
"algorithm": algorithm,
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
},
|
|
||||||
values={
|
|
||||||
"ts_added_ms": time_now,
|
"ts_added_ms": time_now,
|
||||||
"key_json": json_bytes,
|
"key_json": json_bytes,
|
||||||
}
|
}
|
||||||
)
|
for algorithm, key_id, json_bytes in new_keys
|
||||||
return self.runInteraction(
|
],
|
||||||
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
)
|
||||||
|
yield self.runInteraction(
|
||||||
|
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||||
|
|
|
@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
txn.execute(sql, (room_id, ))
|
txn.execute(sql, (room_id, ))
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for event_id, depth in txn:
|
for event_id, depth in txn.fetchall():
|
||||||
hashes = self._get_event_reference_hashes_txn(txn, event_id)
|
hashes = self._get_event_reference_hashes_txn(txn, event_id)
|
||||||
prev_hashes = {
|
prev_hashes = {
|
||||||
k: encode_base64(v) for k, v in hashes.items()
|
k: encode_base64(v) for k, v in hashes.items()
|
||||||
|
|
|
@ -496,7 +496,7 @@ class StateStore(SQLBaseStore):
|
||||||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||||
defer.returnValue(state_map[event_id])
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000)
|
@cached(num_args=2, max_entries=100000)
|
||||||
def _get_state_group_for_event(self, room_id, event_id):
|
def _get_state_group_for_event(self, room_id, event_id):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
|
|
|
@ -216,9 +216,7 @@ class StreamToken(
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def copy_and_replace(self, key, new_value):
|
def copy_and_replace(self, key, new_value):
|
||||||
d = self._asdict()
|
return self._replace(**{key: new_value})
|
||||||
d[key] = new_value
|
|
||||||
return StreamToken(**d)
|
|
||||||
|
|
||||||
|
|
||||||
StreamToken.START = StreamToken(
|
StreamToken.START = StreamToken(
|
||||||
|
|
|
@ -89,6 +89,11 @@ class ObservableDeferred(object):
|
||||||
deferred.addCallbacks(callback, errback)
|
deferred.addCallbacks(callback, errback)
|
||||||
|
|
||||||
def observe(self):
|
def observe(self):
|
||||||
|
"""Observe the underlying deferred.
|
||||||
|
|
||||||
|
Can return either a deferred if the underlying deferred is still pending
|
||||||
|
(or has failed), or the actual value. Callers may need to use maybeDeferred.
|
||||||
|
"""
|
||||||
if not self._result:
|
if not self._result:
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
|
|
||||||
|
@ -101,7 +106,7 @@ class ObservableDeferred(object):
|
||||||
return d
|
return d
|
||||||
else:
|
else:
|
||||||
success, res = self._result
|
success, res = self._result
|
||||||
return defer.succeed(res) if success else defer.fail(res)
|
return res if success else defer.fail(res)
|
||||||
|
|
||||||
def observers(self):
|
def observers(self):
|
||||||
return self._observers
|
return self._observers
|
||||||
|
|
|
@ -15,12 +15,9 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
from synapse.util.logcontext import (
|
|
||||||
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import DEBUG_CACHES, register_cache
|
from . import DEBUG_CACHES, register_cache
|
||||||
|
|
||||||
|
@ -227,8 +224,20 @@ class _CacheDescriptorBase(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_args = num_args
|
self.num_args = num_args
|
||||||
|
|
||||||
|
# list of the names of the args used as the cache key
|
||||||
self.arg_names = all_args[1:num_args + 1]
|
self.arg_names = all_args[1:num_args + 1]
|
||||||
|
|
||||||
|
# self.arg_defaults is a map of arg name to its default value for each
|
||||||
|
# argument that has a default value
|
||||||
|
if arg_spec.defaults:
|
||||||
|
self.arg_defaults = dict(zip(
|
||||||
|
all_args[-len(arg_spec.defaults):],
|
||||||
|
arg_spec.defaults
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
self.arg_defaults = {}
|
||||||
|
|
||||||
if "cache_context" in self.arg_names:
|
if "cache_context" in self.arg_names:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"cache_context arg cannot be included among the cache keys"
|
"cache_context arg cannot be included among the cache keys"
|
||||||
|
@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
iterable=self.iterable,
|
iterable=self.iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_cache_key(args, kwargs):
|
||||||
|
"""Given some args/kwargs return a generator that resolves into
|
||||||
|
the cache_key.
|
||||||
|
|
||||||
|
We loop through each arg name, looking up if its in the `kwargs`,
|
||||||
|
otherwise using the next argument in `args`. If there are no more
|
||||||
|
args then we try looking the arg name up in the defaults
|
||||||
|
"""
|
||||||
|
pos = 0
|
||||||
|
for nm in self.arg_names:
|
||||||
|
if nm in kwargs:
|
||||||
|
yield kwargs[nm]
|
||||||
|
elif pos < len(args):
|
||||||
|
yield args[pos]
|
||||||
|
pos += 1
|
||||||
|
else:
|
||||||
|
yield self.arg_defaults[nm]
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
# whenever we are invalidated
|
# whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
# Add temp cache_context so inspect.getcallargs doesn't explode
|
cache_key = tuple(get_cache_key(args, kwargs))
|
||||||
if self.add_cache_context:
|
|
||||||
kwargs["cache_context"] = None
|
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
|
||||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
|
||||||
|
|
||||||
# Add our own `cache_context` to argument list if the wrapped function
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
# has asked for one
|
# has asked for one
|
||||||
|
@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
defer.returnValue(cached_result)
|
defer.returnValue(cached_result)
|
||||||
observer.addCallback(check_result)
|
observer.addCallback(check_result)
|
||||||
|
|
||||||
return preserve_context_over_deferred(observer)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
ret = defer.maybeDeferred(
|
ret = defer.maybeDeferred(
|
||||||
preserve_context_over_fn,
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
self.function_to_call,
|
|
||||||
obj, *args, **kwargs
|
obj, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
ret.addErrback(onErr)
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
result_d = ObservableDeferred(ret, consumeErrors=True)
|
||||||
cache.set(cache_key, ret, callback=invalidate_callback)
|
cache.set(cache_key, result_d, callback=invalidate_callback)
|
||||||
|
observer = result_d.observe()
|
||||||
|
|
||||||
return preserve_context_over_deferred(ret.observe())
|
if isinstance(observer, defer.Deferred):
|
||||||
|
return logcontext.make_deferred_yieldable(observer)
|
||||||
|
else:
|
||||||
|
return observer
|
||||||
|
|
||||||
wrapped.invalidate = cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
wrapped.invalidate_all = cache.invalidate_all
|
||||||
|
@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
the list of missing keys to the wrapped fucntion.
|
the list of missing keys to the wrapped function.
|
||||||
|
|
||||||
|
Once wrapped, the function returns either a Deferred which resolves to
|
||||||
|
the list of results, or (if all results were cached), just the list of
|
||||||
|
results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orig, cached_method_name, list_name, num_args=None,
|
def __init__(self, orig, cached_method_name, list_name, num_args=None,
|
||||||
|
@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
args_to_call[self.list_name] = missing
|
args_to_call[self.list_name] = missing
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
ret_d = defer.maybeDeferred(
|
||||||
preserve_context_over_fn,
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
self.function_to_call,
|
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
# We need to create deferreds for each arg in the list so that
|
# We need to create deferreds for each arg in the list so that
|
||||||
# we can insert the new deferred into the cache.
|
# we can insert the new deferred into the cache.
|
||||||
for arg in missing:
|
for arg in missing:
|
||||||
with PreserveLoggingContext():
|
observer = ret_d.observe()
|
||||||
observer = ret_d.observe()
|
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||||
|
|
||||||
observer = ObservableDeferred(observer)
|
observer = ObservableDeferred(observer)
|
||||||
|
@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
results.update(res)
|
results.update(res)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return preserve_context_over_deferred(defer.gatherResults(
|
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
cached_defers.values(),
|
cached_defers.values(),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addCallback(update_results_dict).addErrback(
|
).addCallback(update_results_dict).addErrback(
|
||||||
|
|
|
@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
|
||||||
def preserve_context_over_deferred(deferred, context=None):
|
def preserve_context_over_deferred(deferred, context=None):
|
||||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||||
be invoked with the current context.
|
be invoked with the current context.
|
||||||
|
|
||||||
|
Deprecated: this almost certainly doesn't do want you want, ie make
|
||||||
|
the deferred follow the synapse logcontext rules: try
|
||||||
|
``make_deferred_yieldable`` instead.
|
||||||
"""
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
|
@ -359,6 +363,25 @@ def preserve_fn(f):
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def make_deferred_yieldable(deferred):
|
||||||
|
"""Given a deferred, make it follow the Synapse logcontext rules:
|
||||||
|
|
||||||
|
If the deferred has completed (or is not actually a Deferred), essentially
|
||||||
|
does nothing (just returns another completed deferred with the
|
||||||
|
result/failure).
|
||||||
|
|
||||||
|
If the deferred has not yet completed, resets the logcontext before
|
||||||
|
returning a deferred. Then, when the deferred completes, restores the
|
||||||
|
current logcontext before running callbacks/errbacks.
|
||||||
|
|
||||||
|
(This is more-or-less the opposite operation to preserve_fn.)
|
||||||
|
"""
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
r = yield deferred
|
||||||
|
defer.returnValue(r)
|
||||||
|
|
||||||
|
|
||||||
# modules to ignore in `logcontext_tracer`
|
# modules to ignore in `logcontext_tracer`
|
||||||
_to_ignore = [
|
_to_ignore = [
|
||||||
"synapse.util.logcontext",
|
"synapse.util.logcontext",
|
||||||
|
|
|
@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||||
events ([synapse.events.EventBase]): list of events to filter
|
events ([synapse.events.EventBase]): list of events to filter
|
||||||
"""
|
"""
|
||||||
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(store.who_forgot_in_room)(
|
defer.maybeDeferred(
|
||||||
|
preserve_fn(store.who_forgot_in_room),
|
||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
for room_id in frozenset(e.room_id for e in events)
|
for room_id in frozenset(e.room_id for e in events)
|
||||||
|
|
|
@ -19,10 +19,12 @@ from twisted.internet import defer
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
def _regex(regex, exclusive=True):
|
def _regex(regex, exclusive=True):
|
||||||
return {
|
return {
|
||||||
"regex": regex,
|
"regex": re.compile(regex),
|
||||||
"exclusive": exclusive
|
"exclusive": exclusive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
||||||
|
|
||||||
a.func.prefill(("foo",), ObservableDeferred(d))
|
a.func.prefill(("foo",), ObservableDeferred(d))
|
||||||
|
|
||||||
self.assertEquals(a.func("foo").result, d.result)
|
self.assertEquals(a.func("foo"), d.result)
|
||||||
self.assertEquals(callcount[0], 0)
|
self.assertEquals(callcount[0], 0)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -12,11 +12,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
import mock
|
import mock
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.util import async
|
||||||
|
from synapse.util import logcontext
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.caches import descriptors
|
from synapse.util.caches import descriptors
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DescriptorTestCase(unittest.TestCase):
|
class DescriptorTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -84,3 +91,125 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
r = yield obj.fn(2, 5)
|
r = yield obj.fn(2, 5)
|
||||||
self.assertEqual(r, 'chips')
|
self.assertEqual(r, 'chips')
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_logcontexts(self):
|
||||||
|
"""Check that logcontexts are set and restored correctly when
|
||||||
|
using the cache."""
|
||||||
|
|
||||||
|
complete_lookup = defer.Deferred()
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def inner_fn():
|
||||||
|
with logcontext.PreserveLoggingContext():
|
||||||
|
yield complete_lookup
|
||||||
|
defer.returnValue(1)
|
||||||
|
|
||||||
|
return inner_fn()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_lookup():
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.name = "c1"
|
||||||
|
r = yield obj.fn(1)
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
c1)
|
||||||
|
defer.returnValue(r)
|
||||||
|
|
||||||
|
def check_result(r):
|
||||||
|
self.assertEqual(r, 1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# set off a deferred which will do a cache lookup
|
||||||
|
d1 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
d1.addCallback(check_result)
|
||||||
|
|
||||||
|
# and another
|
||||||
|
d2 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
d2.addCallback(check_result)
|
||||||
|
|
||||||
|
# let the lookup complete
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
|
||||||
|
return defer.gatherResults([d1, d2])
|
||||||
|
|
||||||
|
def test_cache_logcontexts_with_exception(self):
|
||||||
|
"""Check that the cache sets and restores logcontexts correctly when
|
||||||
|
the lookup function throws an exception"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def inner_fn():
|
||||||
|
yield async.run_on_reactor()
|
||||||
|
raise SynapseError(400, "blah")
|
||||||
|
|
||||||
|
return inner_fn()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_lookup():
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.name = "c1"
|
||||||
|
try:
|
||||||
|
yield obj.fn(1)
|
||||||
|
self.fail("No exception thrown")
|
||||||
|
except SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
c1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# set off a deferred which will do a cache lookup
|
||||||
|
d1 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
|
||||||
|
return d1
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache_default_args(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2=2, arg3=3):
|
||||||
|
return self.mock(arg1, arg2, arg3)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
obj.mock.return_value = 'fish'
|
||||||
|
r = yield obj.fn(1, 2, 3)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_called_once_with(1, 2, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with same params shouldn't call the mock again
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = 'chips'
|
||||||
|
r = yield obj.fn(2, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_called_once_with(2, 3, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(2, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
|
@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
|
||||||
# before the cache expires returns a resolved deferred.
|
# before the cache expires returns a resolved deferred.
|
||||||
get_result_at_11 = self.cache.get(11, "key")
|
get_result_at_11 = self.cache.get(11, "key")
|
||||||
self.assertIsNotNone(get_result_at_11)
|
self.assertIsNotNone(get_result_at_11)
|
||||||
self.assertTrue(get_result_at_11.called)
|
if isinstance(get_result_at_11, Deferred):
|
||||||
|
# The cache may return the actual result rather than a deferred
|
||||||
|
self.assertTrue(get_result_at_11.called)
|
||||||
|
|
||||||
# Check that getting the key after the deferred has resolved
|
# Check that getting the key after the deferred has resolved
|
||||||
# after the cache expires returns None
|
# after the cache expires returns None
|
||||||
|
|
Loading…
Reference in a new issue