put resolve_service in an object

this makes it easier to stub things out for tests.
This commit is contained in:
Richard van der Hoff 2019-01-22 17:42:26 +00:00
parent 53a327b4d5
commit 7021784d46
3 changed files with 96 additions and 75 deletions

View file

@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
Args:
reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(self, reactor, tls_client_options_factory):
def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None,
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
if port is not None:
target = (host, port)
else:
server_list = yield resolve_service(server_name_bytes)
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)

View file

@ -84,73 +84,86 @@ def pick_server_from_list(server_list):
)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
"""Look up a SRV record, with caching
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
service_name (bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
clock (object): clock implementation. must provide a time() method.
Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
get_time (callable): clock implementation. Should return seconds since the epoch
"""
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
@defer.inlineCallbacks
def resolve_service(self, service_name):
"""Look up a SRV record
try:
answers, _, _ = yield make_deferred_yieldable(
dns_client.lookupService(service_name),
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
Args:
service_name (bytes): record to look up
Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
)
defer.returnValue(list(cache_entry))
else:
raise e
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=int(clock.time()) + answer.ttl,
))
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
))
cache[service_name] = list(servers)
defer.returnValue(servers)
self._cache[service_name] = list(servers)
defer.returnValue(servers)

View file

@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import resolve_service
from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext
from tests import unittest
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = result_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def do_lookup():
with LoggingContext("one") as ctx:
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
@ -89,10 +89,9 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
servers = yield resolver.resolve_service(service_name)
dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -112,11 +111,12 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 999999999
cache = {service_name: [entry]}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
resolver = SrvResolver(
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
)
servers = yield resolver.resolve_service(service_name)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
@ -131,9 +131,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
yield resolver.resolve_service(service_name)
@defer.inlineCallbacks
def test_name_error(self):
@ -144,10 +145,9 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
servers = yield resolver.resolve_service(service_name)
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
lookup_deferred.callback((