mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 09:35:45 +03:00
Do an AAAA lookup on SRV record targets (#2462)
Support SRV records which point at AAAA records, as well as A records. Fixes https://github.com/matrix-org/synapse/issues/2405
This commit is contained in:
parent
f496399ac4
commit
f65e31d22f
2 changed files with 118 additions and 24 deletions
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import socket
|
||||||
|
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SERVER_CACHE = {}
|
SERVER_CACHE = {}
|
||||||
|
|
||||||
|
# our record of an individual server which can be tried to reach a destination.
|
||||||
|
#
|
||||||
|
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
|
||||||
|
# no SRV record, in which case it is the original hostname.
|
||||||
_Server = collections.namedtuple(
|
_Server = collections.namedtuple(
|
||||||
"_Server", "priority weight host port expires"
|
"_Server", "priority weight host port expires"
|
||||||
)
|
)
|
||||||
|
@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
|
||||||
return self.default_server
|
return self.default_server
|
||||||
else:
|
else:
|
||||||
raise ConnectError(
|
raise ConnectError(
|
||||||
"Not server available for %s" % self.service_name
|
"No server available for %s" % self.service_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# look for all servers with the same priority
|
||||||
min_priority = self.servers[0].priority
|
min_priority = self.servers[0].priority
|
||||||
weight_indexes = list(
|
weight_indexes = list(
|
||||||
(index, server.weight + 1)
|
(index, server.weight + 1)
|
||||||
|
@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
|
||||||
|
|
||||||
total_weight = sum(weight for index, weight in weight_indexes)
|
total_weight = sum(weight for index, weight in weight_indexes)
|
||||||
target_weight = random.randint(0, total_weight)
|
target_weight = random.randint(0, total_weight)
|
||||||
|
|
||||||
for index, weight in weight_indexes:
|
for index, weight in weight_indexes:
|
||||||
target_weight -= weight
|
target_weight -= weight
|
||||||
if target_weight <= 0:
|
if target_weight <= 0:
|
||||||
server = self.servers[index]
|
server = self.servers[index]
|
||||||
|
# XXX: this looks totally dubious:
|
||||||
|
#
|
||||||
|
# (a) we never reuse a server until we have been through
|
||||||
|
# all of the servers at the same priority, so if the
|
||||||
|
# weights are A: 100, B:1, we always do ABABAB instead of
|
||||||
|
# AAAA...AAAB (approximately).
|
||||||
|
#
|
||||||
|
# (b) After using all the servers at the lowest priority,
|
||||||
|
# we move onto the next priority. We should only use the
|
||||||
|
# second priority if servers at the top priority are
|
||||||
|
# unreachable.
|
||||||
|
#
|
||||||
del self.servers[index]
|
del self.servers[index]
|
||||||
self.used_servers.append(server)
|
self.used_servers.append(server)
|
||||||
return server
|
return server
|
||||||
|
@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = answer.payload
|
payload = answer.payload
|
||||||
host = str(payload.target)
|
|
||||||
srv_ttl = answer.ttl
|
|
||||||
|
|
||||||
try:
|
hosts = yield _get_hosts_for_srv_record(
|
||||||
answers, _, _ = yield dns_client.lookupAddress(host)
|
dns_client, str(payload.target)
|
||||||
except DNSNameError:
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
for answer in answers:
|
for (ip, ttl) in hosts:
|
||||||
if answer.type == dns.A and answer.payload:
|
host_ttl = min(answer.ttl, ttl)
|
||||||
ip = answer.payload.dottedQuad()
|
|
||||||
host_ttl = min(srv_ttl, answer.ttl)
|
|
||||||
|
|
||||||
servers.append(_Server(
|
servers.append(_Server(
|
||||||
host=ip,
|
host=ip,
|
||||||
port=int(payload.port),
|
port=int(payload.port),
|
||||||
priority=int(payload.priority),
|
priority=int(payload.priority),
|
||||||
weight=int(payload.weight),
|
weight=int(payload.weight),
|
||||||
expires=int(clock.time()) + host_ttl,
|
expires=int(clock.time()) + host_ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
servers.sort()
|
servers.sort()
|
||||||
cache[service_name] = list(servers)
|
cache[service_name] = list(servers)
|
||||||
|
@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
defer.returnValue(servers)
|
defer.returnValue(servers)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_hosts_for_srv_record(dns_client, host):
|
||||||
|
"""Look up each of the hosts in a SRV record
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dns_client (twisted.names.dns.IResolver):
|
||||||
|
host (basestring): host to look up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
|
||||||
|
|
||||||
|
"""
|
||||||
|
ip4_servers = []
|
||||||
|
ip6_servers = []
|
||||||
|
|
||||||
|
def cb(res):
|
||||||
|
# lookupAddress and lookupIP6Address return a three-tuple
|
||||||
|
# giving the answer, authority, and additional sections of the
|
||||||
|
# response.
|
||||||
|
#
|
||||||
|
# we only care about the answers.
|
||||||
|
|
||||||
|
return res[0]
|
||||||
|
|
||||||
|
def eb(res):
|
||||||
|
res.trap(DNSNameError)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||||
|
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
||||||
|
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
||||||
|
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
for answer in result:
|
||||||
|
if not answer.payload:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if answer.type == dns.A:
|
||||||
|
ip = answer.payload.dottedQuad()
|
||||||
|
ip4_servers.append((ip, answer.ttl))
|
||||||
|
elif answer.type == dns.AAAA:
|
||||||
|
ip = socket.inet_ntop(
|
||||||
|
socket.AF_INET6, answer.payload.address,
|
||||||
|
)
|
||||||
|
ip6_servers.append((ip, answer.ttl))
|
||||||
|
else:
|
||||||
|
# the most likely candidate here is a CNAME record.
|
||||||
|
# rfc2782 says srvs may not point to aliases.
|
||||||
|
logger.warn(
|
||||||
|
"Ignoring unexpected DNS record type %s for %s",
|
||||||
|
answer.type, host,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn("Ignoring invalid DNS response for %s: %s",
|
||||||
|
host, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# keep the ipv4 results before the ipv6 results, mostly to match historical
|
||||||
|
# behaviour.
|
||||||
|
defer.returnValue(ip4_servers + ip6_servers)
|
||||||
|
|
|
@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.DEBUG
|
||||||
class DnsTestCase(unittest.TestCase):
|
class DnsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve(self):
|
def test_resolve(self):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
|
|
||||||
service_name = "test_service.examle.com"
|
service_name = "test_service.example.com"
|
||||||
host_name = "example.com"
|
host_name = "example.com"
|
||||||
ip_address = "127.0.0.1"
|
ip_address = "127.0.0.1"
|
||||||
|
ip6_address = "::1"
|
||||||
|
|
||||||
answer_srv = dns.RRHeader(
|
answer_srv = dns.RRHeader(
|
||||||
type=dns.SRV,
|
type=dns.SRV,
|
||||||
|
@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
|
answer_aaaa = dns.RRHeader(
|
||||||
dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
|
type=dns.AAAA,
|
||||||
|
payload=dns.Record_AAAA(
|
||||||
|
address=ip6_address,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dns_client_mock.lookupService.return_value = defer.succeed(
|
||||||
|
([answer_srv], None, None),
|
||||||
|
)
|
||||||
|
dns_client_mock.lookupAddress.return_value = defer.succeed(
|
||||||
|
([answer_a], None, None),
|
||||||
|
)
|
||||||
|
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
|
||||||
|
([answer_aaaa], None, None),
|
||||||
|
)
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
|
||||||
|
@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
|
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
|
||||||
|
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
|
||||||
|
|
||||||
self.assertEquals(len(servers), 1)
|
self.assertEquals(len(servers), 2)
|
||||||
self.assertEquals(servers, cache[service_name])
|
self.assertEquals(servers, cache[service_name])
|
||||||
self.assertEquals(servers[0].host, ip_address)
|
self.assertEquals(servers[0].host, ip_address)
|
||||||
|
self.assertEquals(servers[1].host, ip6_address)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_from_cache_expired_and_dns_fail(self):
|
def test_from_cache_expired_and_dns_fail(self):
|
||||||
|
|
Loading…
Reference in a new issue