diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8709394b97..a859872ce2 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter -from synapse.util.async import create_observer +from synapse.util.async import ObservableDeferred from OpenSSL import crypto @@ -111,6 +111,10 @@ class Keyring(object): if download is None: download = self._get_server_verify_key_impl(server_name, key_ids) + download = ObservableDeferred( + download, + consumeErrors=True + ) self.key_downloads[server_name] = download @download.addBoth @@ -118,7 +122,7 @@ class Keyring(object): del self.key_downloads[server_name] return ret - r = yield create_observer(download) + r = yield download.observe() defer.returnValue(r) @defer.inlineCallbacks diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 08c8d75af4..4af5f73878 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -25,7 +25,7 @@ from twisted.internet import defer from twisted.web.resource import Resource from twisted.protocols.basic import FileSender -from synapse.util.async import create_observer +from synapse.util.async import ObservableDeferred import os @@ -83,13 +83,17 @@ class BaseMediaResource(Resource): download = self.downloads.get(key) if download is None: download = self._get_remote_media_impl(server_name, media_id) + download = ObservableDeferred( + download, + consumeErrors=True + ) self.downloads[key] = download @download.addBoth def callback(media_info): del self.downloads[key] return media_info - return create_observer(download) + return download.observe() @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): diff --git a/synapse/util/async.py b/synapse/util/async.py index d8febdb90c..34acb14a6f 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -34,20 +34,56 @@ def run_on_reactor(): return sleep(0) -def create_observer(deferred): - """Creates a deferred that observes the result or failure of the given - deferred *without* affecting the given deferred. +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. """ - d = defer.Deferred() - def callback(r): - d.callback(r) - return r + __slots__ = ["_deferred", "_observers", "_result"] - def errback(f): - d.errback(f) - return f + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", []) - deferred.addCallbacks(callback, errback) + def callback(r): + self._result = (True, r) + while self._observers: + try: + self._observers.pop().callback(r) + except: + pass + return r - return d + def errback(f): + self._result = (False, f) + while self._observers: + try: + self._observers.pop().errback(f) + except: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) + + def observe(self): + if not self._result: + d = defer.Deferred() + self._observers.append(d) + return d + else: + success, res = self._result + return defer.succeed(res) if success else defer.fail(res) + + def __getattr__(self, name): + return getattr(self._deferred, name) + + def __setattr__(self, name, value): + setattr(self._deferred, name, value)