AdGuardHome/internal/dnsforward/dnsforward_test.go
Stanislav Chzhen 6363f8a2e7 Pull request 2286: AGDNS-2374-slog-safesearch
Squashed commit of the following:

commit 1909dfed99b8815c1215c709efcae77a70b52ea3
Merge: 3856fda5f 2c64ab5a5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 9 16:21:38 2024 +0300

    Merge branch 'master' into AGDNS-2374-slog-safesearch

commit 3856fda5f3
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Oct 8 20:04:34 2024 +0300

    home: imp code

commit de774009aa
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Oct 7 16:41:58 2024 +0300

    all: imp code

commit 038bae59d5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Oct 3 20:24:48 2024 +0300

    all: imp code

commit 792975e248
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Oct 3 15:46:40 2024 +0300

    all: slog safesearch
2024-10-09 16:31:03 +03:00

1673 lines
44 KiB
Go

package dnsforward
import (
"cmp"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/netip"
"sync"
"sync/atomic"
"testing"
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMain hands control of the package's test run over to
// [testutil.DiscardLogOutput].
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
//
// TODO(a.garipov): Use more.
const testTimeout = 1 * time.Second
// testQuestionTarget is the common question target for tests.
//
// TODO(a.garipov): Use more.
const testQuestionTarget = "target.example"
// Common constants for the TLS and message-sending test helpers.
const (
// tlsServerName is the DNS name added to the generated test certificate and
// used as the TLS server name in the test TLS configurations.
tlsServerName = "testdns.adguard.com"
// testMessagesCount is the number of messages sent by the sendTestMessages
// and sendTestMessagesAsync helpers.
testMessagesCount = 10
)
// testClientAddrPort is the common client address and port, as a
// [netip.AddrPort], for tests.
//
// TODO(a.garipov): Use more.
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
// startDeferStop starts s, failing the test immediately if the start returns
// an error, and hands s.Stop to [testutil.CleanupAndRequireSuccess] so that
// the server is stopped when the test ends.
func startDeferStop(t *testing.T, s *Server) {
	t.Helper()

	require.NoError(t, s.Start())

	testutil.CleanupAndRequireSuccess(t, s.Stop)
}
// createTestServer returns a new *Server for tests.  The server uses a DNS
// filter built from filterConf and a fixed set of test rules, a DHCP mock that
// reports itself as disabled, and is prepared with forwardConf.  The returned
// server is not started.
func createTestServer(
	t *testing.T,
	filterConf *filtering.Config,
	forwardConf ServerConfig,
) (s *Server) {
	t.Helper()

	// testRules are the filtering rules that the blocking tests rely on.
	const testRules = `||nxdomain.example.org
||NULL.example.org^
127.0.0.1 host.example.org
@@||whitelist.example.org^
||127.0.0.255`

	dnsFilter, err := filtering.New(filterConf, []filtering.Filter{{
		ID:   0,
		Data: []byte(testRules),
	}})
	require.NoError(t, err)

	dnsFilter.SetEnabled(true)

	s, err = NewServer(DNSCreateParams{
		DHCPServer: &testDHCP{
			OnEnabled:  func() (ok bool) { return false },
			OnHostByIP: func(_ netip.Addr) (host string) { return "" },
			OnIPByHost: func(_ string) (ip netip.Addr) { panic("not implemented") },
		},
		DNSFilter:   dnsFilter,
		PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		Logger:      slogutil.NewDiscardLogger(),
	})
	require.NoError(t, err)

	require.NoError(t, s.Prepare(&forwardConf))

	return s
}
// createServerTLSConfig generates a fresh RSA key and a self-signed
// certificate for tlsServerName and returns a TLS configuration using them
// along with the PEM-encoded certificate and the PEM-encoded private key, in
// that order.  Any failure stops the test immediately.
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoErrorf(t, err, "cannot generate RSA key: %s", err)

	// Use a random serial number below 2^128.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	require.NoErrorf(t, err, "failed to generate serial number: %s", err)

	notBefore := time.Now()
	notAfter := notBefore.Add(5 * 365 * timeutil.Day)

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"AdGuard Tests"},
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	template.DNSNames = append(template.DNSNames, tlsServerName)

	// The template is used both as the certificate and as its parent, which
	// makes the certificate self-signed.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
	require.NoErrorf(t, err, "failed to create certificate: %s", err)

	certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoErrorf(t, err, "failed to create key pair: %s", err)

	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
		MinVersion:   tls.VersionTLS12,
	}, certPem, keyPem
}
// createTestTLS returns a test server prepared with tlsConf, into which a
// freshly generated certificate chain and private key are put, as well as the
// PEM-encoded certificate so that clients can trust it.  The server is not
// started.
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
	t.Helper()

	_, certPem, keyPem := createServerTLSConfig(t)

	filterConf := &filtering.Config{
		BlockingMode: filtering.BlockingModeDefault,
	}
	srvConf := ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamMode:     UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
		},
		ServePlainDNS: true,
	}
	s = createTestServer(t, filterConf, srvConf)

	tlsConf.CertificateChainData = certPem
	tlsConf.PrivateKeyData = keyPem
	s.conf.TLSConfig = tlsConf

	// Prepare the server again so that the TLS settings are applied.
	err := s.Prepare(&s.conf)
	require.NoErrorf(t, err, "failed to prepare server: %s", err)

	return s, certPem
}
// googleDomainName is the fully qualified domain name that the mock upstream
// returned by newGoogleUpstream answers for.
const googleDomainName = "google-public-dns-a.google.com."

// createGoogleATestMessage returns a new type-A request for googleDomainName.
func createGoogleATestMessage() *dns.Msg {
	req := createTestMessage(googleDomainName)

	return req
}
// newGoogleUpstream returns a mock upstream that answers type-A requests for
// googleDomainName with 8.8.8.8 and replies with NXDOMAIN to every request
// that aghtest.MatchedResponse does not produce a response for.
func newGoogleUpstream() (u upstream.Upstream) {
	onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		resp = aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8")
		if resp == nil {
			resp = new(dns.Msg).SetRcode(req, dns.RcodeNameError)
		}

		return resp, nil
	}

	return &aghtest.UpstreamMock{
		OnAddress:  func() (addr string) { return "google.upstream.example" },
		OnExchange: onExchange,
		OnClose:    func() (err error) { return nil },
	}
}
// createTestMessage returns a new DNS request with a random ID, the
// recursion-desired flag set, and a single IN-class type-A question for host.
func createTestMessage(host string) *dns.Msg {
	header := dns.MsgHdr{
		Id:               dns.Id(),
		RecursionDesired: true,
	}
	question := dns.Question{
		Name:   host,
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	return &dns.Msg{
		MsgHdr:   header,
		Question: []dns.Question{question},
	}
}
// createTestMessageWithType is like createTestMessage, but the question type
// is set to qtype instead of A.
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
	msg := createTestMessage(host)

	question := &msg.Question[0]
	question.Qtype = qtype

	return msg
}
// newResp returns a new DNS response to req with the response code set to
// rcode and ans used as the answer section.  The response is marked as
// recursion-available and has compression enabled.
func newResp(rcode int, req *dns.Msg, ans []dns.RR) (resp *dns.Msg) {
	resp = new(dns.Msg)
	resp.SetRcode(req, rcode)

	resp.RecursionAvailable, resp.Compress = true, true
	resp.Answer = ans

	return resp
}
// assertGoogleAResponse checks that reply contains exactly one A record with
// the address 8.8.8.8, which is the one that the mock upstream returned by
// newGoogleUpstream answers with.
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
	// Mark as a helper, like assertResponse, so that failures are reported at
	// the calling test's line.
	t.Helper()

	assertResponse(t, reply, netip.AddrFrom4([4]byte{8, 8, 8, 8}))
}
// assertResponse checks that reply contains exactly one answer, that the
// answer is an A record, and that its address equals ip.  A wrong number of
// answers or a wrong record type stops the test immediately.
func assertResponse(t *testing.T, reply *dns.Msg, ip netip.Addr) {
	t.Helper()

	answers := reply.Answer
	require.Lenf(t, answers, 1, "dns server returned reply with wrong number of answers - %d", len(answers))

	rr := answers[0]
	aRec, isA := rr.(*dns.A)
	require.Truef(t, isA, "dns server returned wrong answer type instead of A: %v", rr)

	want := net.IP(ip.AsSlice())
	assert.Equal(t, want, aRec.A)
}
// sendTestMessagesAsync sends messages in parallel to check for race issues.
// The exchanges run in separate goroutines, but their results are checked in
// the calling goroutine, because the require assertions stop the test with
// FailNow, which must only be called from the goroutine running the test.
//
//lint:ignore U1000 it's called from the function which is skipped for now.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
	t.Helper()

	// exchangeResult is the outcome of a single write-and-read exchange.
	type exchangeResult struct {
		resp *dns.Msg
		err  error
	}

	// The channel is buffered for all results so that the goroutines never
	// block on sending.
	results := make(chan exchangeResult, testMessagesCount)

	wg := &sync.WaitGroup{}
	for range testMessagesCount {
		msg := createGoogleATestMessage()

		wg.Add(1)
		go func() {
			defer wg.Done()

			err := conn.WriteMsg(msg)
			if err != nil {
				results <- exchangeResult{err: fmt.Errorf("cannot write message: %w", err)}

				return
			}

			resp, err := conn.ReadMsg()
			if err != nil {
				err = fmt.Errorf("cannot read response to message: %w", err)
			}

			results <- exchangeResult{resp: resp, err: err}
		}()
	}

	wg.Wait()
	close(results)

	for res := range results {
		require.NoError(t, res.err)
		assertGoogleAResponse(t, res.resp)
	}
}
// sendTestMessages sequentially sends testMessagesCount requests for
// googleDomainName over conn and checks every response.  A write or read
// error stops the test immediately, since there is no response to inspect
// after it.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	t.Helper()

	for i := range testMessagesCount {
		req := createGoogleATestMessage()

		err := conn.WriteMsg(req)
		require.NoErrorf(t, err, "cannot write message #%d: %s", i, err)

		res, err := conn.ReadMsg()
		require.NoErrorf(t, err, "cannot read response to message #%d: %s", i, err)

		assertGoogleAResponse(t, res)
	}
}
func TestServer(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
testCases := []struct {
name string
net string
proto proxy.Proto
}{{
name: "message_over_udp",
net: "",
proto: proxy.ProtoUDP,
}, {
name: "message_over_tcp",
net: "tcp",
proto: proxy.ProtoTCP,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
addr := s.dnsProxy.Addr(tc.proto)
client := dns.Client{Net: tc.net}
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
assertGoogleAResponse(t, reply)
})
}
}
func TestServer_timeout(t *testing.T) {
t.Run("custom", func(t *testing.T) {
srvConf := &ServerConfig{
UpstreamTimeout: testTimeout,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s, err := NewServer(DNSCreateParams{
DNSFilter: createTestDNSFilter(t),
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err)
err = s.Prepare(srvConf)
require.NoError(t, err)
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
})
t.Run("default", func(t *testing.T) {
s, err := NewServer(DNSCreateParams{
DNSFilter: createTestDNSFilter(t),
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err)
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false,
}
err = s.Prepare(&s.conf)
require.NoError(t, err)
assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout)
})
}
func TestServer_Prepare_fallbacks(t *testing.T) {
srvConf := &ServerConfig{
Config: Config{
FallbackDNS: []string{
"#tls://1.1.1.1",
"8.8.8.8",
},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s, err := NewServer(DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err)
err = s.Prepare(srvConf)
require.NoError(t, err)
require.NotNil(t, s.dnsProxy.Fallbacks)
assert.Len(t, s.dnsProxy.Fallbacks.Upstreams, 1)
}
func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
// Message over UDP.
req := createGoogleATestMessage()
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
client := &dns.Client{}
reply, _, err := client.Exchange(req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
assertGoogleAResponse(t, reply)
}
// TestDoTServer checks that the server answers requests sent over a
// DNS-over-TLS connection that trusts the server's self-signed certificate.
func TestDoTServer(t *testing.T) {
	s, certPem := createTestTLS(t, TLSConfig{
		TLSListenAddrs: []*net.TCPAddr{{}},
	})
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
	startDeferStop(t, s)

	// Add our self-signed generated certificate to roots.  Fail early if it
	// could not be parsed, since the TLS handshake would fail with a less
	// obvious error otherwise.
	roots := x509.NewCertPool()
	ok := roots.AppendCertsFromPEM(certPem)
	require.True(t, ok, "cannot add the test certificate to the root pool")

	tlsConfig := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
		MinVersion: tls.VersionTLS12,
	}

	// Create a DNS-over-TLS client connection.
	addr := s.dnsProxy.Addr(proxy.ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)

	// Close the connection when the test ends.  Cleanups run in the reverse
	// order of registration, so this happens before the server is stopped.
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	sendTestMessages(t, conn)
}
// TestDoQServer checks that the server answers a request sent through a
// DNS-over-QUIC upstream pointed at the server's own QUIC address.
func TestDoQServer(t *testing.T) {
	s, _ := createTestTLS(t, TLSConfig{
		QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
	})
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
	startDeferStop(t, s)

	// Create a DNS-over-QUIC upstream.  Certificate verification is skipped,
	// since the server's certificate is self-signed.
	addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
	opts := &upstream.Options{InsecureSkipVerify: true}
	u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
	require.NoError(t, err)

	// Close the upstream when the test ends.  Cleanups run in the reverse
	// order of registration, so this happens before the server is stopped.
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Send the test message.
	req := createGoogleATestMessage()
	res, err := u.Exchange(req)
	require.NoError(t, err)

	assertGoogleAResponse(t, res)
}
func TestServerRace(t *testing.T) {
t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks")
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
// Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial("udp", addr.String())
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn)
}
// TestSafeSearch checks the responses to requests for search-engine domains
// when safe search is enabled for Google and Yandex: Yandex domains must
// resolve to a fixed address, and Google domains must be answered with a CNAME
// to the safe-search host followed by an A record.
func TestSafeSearch(t *testing.T) {
	// googleSafeCNAME is the CNAME target expected for all Google domains.
	const googleSafeCNAME = "forcesafesearch.google.com."

	safeSearchConf := filtering.SafeSearchConfig{
		Enabled: true,
		Google:  true,
		Yandex:  true,
	}

	filterConf := &filtering.Config{
		BlockingMode:        filtering.BlockingModeDefault,
		ProtectionEnabled:   true,
		SafeSearchConf:      safeSearchConf,
		SafeSearchCacheSize: 1000,
		CacheTime:           30,
	}

	ctx := testutil.ContextWithTimeout(t, testTimeout)
	safeSearch, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
		Logger:         slogutil.NewDiscardLogger(),
		ServicesConfig: safeSearchConf,
		CacheSize:      filterConf.SafeSearchCacheSize,
		CacheTTL:       time.Minute * time.Duration(filterConf.CacheTime),
	})
	require.NoError(t, err)

	filterConf.SafeSearch = safeSearch

	s := createTestServer(t, filterConf, ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	})
	startDeferStop(t, s)

	srvAddr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
	cli := &dns.Client{}

	yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})

	// An empty wantCNAME means that a single A record with the address want
	// is expected; otherwise, want is left as the zero value and is not used.
	testCases := []struct {
		host      string
		want      netip.Addr
		wantCNAME string
	}{
		{host: "yandex.com.", want: yandexIP},
		{host: "yandex.by.", want: yandexIP},
		{host: "yandex.kz.", want: yandexIP},
		{host: "yandex.ru.", want: yandexIP},
		{host: "www.google.com.", wantCNAME: googleSafeCNAME},
		{host: "www.google.com.af.", wantCNAME: googleSafeCNAME},
		{host: "www.google.be.", wantCNAME: googleSafeCNAME},
		{host: "www.google.by.", wantCNAME: googleSafeCNAME},
	}

	for _, tc := range testCases {
		t.Run(tc.host, func(t *testing.T) {
			req := createTestMessage(tc.host)

			// Retry the exchange until it succeeds and keep the first
			// successful response.
			//
			// TODO(a.garipov): Create our own helper for this.
			var reply *dns.Msg
			once := &sync.Once{}
			require.EventuallyWithT(t, func(c *assert.CollectT) {
				r, _, errExch := cli.Exchange(req, srvAddr)
				if !assert.NoError(c, errExch) {
					return
				}

				once.Do(func() { reply = r })
			}, testTimeout*10, testTimeout)

			if tc.wantCNAME == "" {
				require.Len(t, reply.Answer, 1)

				a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
				assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)

				return
			}

			require.Len(t, reply.Answer, 2)

			cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
			assert.Equal(t, tc.wantCNAME, cname.Target)

			a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
			assert.NotEmpty(t, a.A)
		})
	}
}
// TestInvalidRequest checks that the server still sends a response, within
// testTimeout, to a request that has no question section.
func TestInvalidRequest(t *testing.T) {
	s := createTestServer(t, &filtering.Config{
		BlockingMode: filtering.BlockingModeDefault,
	}, ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	})
	startDeferStop(t, s)

	addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
	req := dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
	}

	// Send a DNS request without question.
	_, _, err := (&dns.Client{
		Timeout: testTimeout,
	}).Exchange(&req, addr)

	// An error here means that the exchange failed, i.e. that no response to
	// the invalid query was received.
	assert.NoErrorf(t, err, "got no response to an invalid query: %s", err)
}
// TestBlockedRequest checks that, in the default blocking mode, a type-A
// request for a host matched by a blocking rule is answered successfully with
// a single A record containing the unspecified address.
func TestBlockedRequest(t *testing.T) {
	forwardConf := ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	}
	s := createTestServer(t, &filtering.Config{
		ProtectionEnabled: true,
		BlockingMode:      filtering.BlockingModeDefault,
	}, forwardConf)
	startDeferStop(t, s)

	addr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Default blocking.
	req := createTestMessage("nxdomain.example.org.")

	reply, err := dns.Exchange(req, addr.String())
	require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)

	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
	require.Len(t, reply.Answer, 1)

	// Fail the test instead of panicking if the answer is not an A record.
	a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
	assert.True(t, a.A.IsUnspecified())
}
// TestServerCustomClientUpstream checks that requests are resolved through the
// custom upstream configuration returned by the clients container, and that
// repeating the same request does not reach the upstream a second time.
func TestServerCustomClientUpstream(t *testing.T) {
	const defaultCacheSize = 1024 * 1024

	// upsCalledCounter counts the exchanges performed by the custom upstream.
	// It is atomic, since the upstream is called from the server's goroutines.
	var upsCalledCounter atomic.Uint32

	forwardConf := ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			CacheSize:    defaultCacheSize,
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	}
	s := createTestServer(t, &filtering.Config{
		BlockingMode: filtering.BlockingModeDefault,
	}, forwardConf)

	ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
		upsCalledCounter.Add(1)

		return cmp.Or(
			aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
			new(dns.Msg).SetRcode(req, dns.RcodeNameError),
		), nil
	})

	customUpsConf := proxy.NewCustomUpstreamConfig(
		&proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		true,
		defaultCacheSize,
		forwardConf.EDNSClientSubnet.Enabled,
	)

	// Return the same custom configuration for every client.
	s.conf.ClientsContainer = &aghtest.ClientsContainer{
		OnUpstreamConfigByID: func(
			_ string,
			_ upstream.Resolver,
		) (conf *proxy.CustomUpstreamConfig, err error) {
			return customUpsConf, nil
		},
	}

	startDeferStop(t, s)

	addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	// Send test request.
	req := createTestMessage("host.")

	reply, err := dns.Exchange(req, addr)
	require.NoError(t, err)
	require.Len(t, reply.Answer, 1)

	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)

	// Fail the test instead of panicking if the answer is not an A record.
	a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
	assert.Equal(t, net.IP{192, 168, 0, 1}, a.A)
	assert.Equal(t, uint32(1), upsCalledCounter.Load())

	// Repeat the request.  The counter must not change, i.e. the response is
	// presumably served from the custom configuration's cache.
	_, err = dns.Exchange(req, addr)
	require.NoError(t, err)
	assert.Equal(t, uint32(1), upsCalledCounter.Load())
}
// testCNAMEs is a map of names and their CNAMEs used as the data of the test
// upstreams in the CNAME-blocking tests.
var testCNAMEs = map[string][]string{
"badhost.": {"NULL.example.org."},
"whitelist.example.org.": {"NULL.example.org."},
}
// testIPv4 is a map of names and their IPv4 addresses used as the data of the
// test upstreams in the CNAME-blocking tests.
var testIPv4 = map[string][]net.IP{
"NULL.example.org.": {{1, 2, 3, 4}},
"example.org.": {{127, 0, 0, 255}},
}
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
})
testUpstm := &aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
}
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{testUpstm},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'NULL.example.org' which should be
// blocked by filters, but protection is disabled so it is not.
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
}
func TestBlockCNAME(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
testCases := []struct {
name string
host string
want bool
}{{
name: "block_request",
host: "badhost.",
// 'badhost' has a canonical name 'NULL.example.org' which is
// blocked by filters: response is blocked.
want: true,
}, {
name: "allowed",
host: "whitelist.example.org.",
// 'whitelist.example.org' has a canonical name
// 'NULL.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist:
// response isn't blocked.
want: false,
}, {
name: "block_response",
host: "example.org.",
// 'example.org' has a canonical name 'cname1' with IP
// 127.0.0.255 which is blocked by filters: response is blocked.
want: true,
}}
for _, tc := range testCases {
req := createTestMessage(tc.host)
t.Run(tc.name, func(t *testing.T) {
reply, err := dns.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
if tc.want {
require.Len(t, reply.Answer, 1)
ans := reply.Answer[0]
a, ok := ans.(*dns.A)
require.True(t, ok)
assert.True(t, a.A.IsUnspecified())
}
})
}
}
func TestClientRulesForCNAMEMatching(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) {
settings.FilteringEnabled = false
},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'NULL.example.org' which is blocked by
// filters: response is blocked.
req := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
},
Question: []dns.Question{{
Name: "badhost.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
// However, in our case it should not be blocked as filtering is
// disabled on the client level.
reply, err := dns.Exchange(&req, addr.String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
}
func TestNullBlockedRequest(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeNullIP,
}, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// Nil filter blocking.
req := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "NULL.example.org.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
reply, err := dns.Exchange(&req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
}
// TestBlockedCustomIP checks the custom-IP blocking mode: Prepare must fail
// while the custom blocking addresses are unset, and once they are set,
// blocked A and AAAA requests must be answered with those addresses.
func TestBlockedCustomIP(t *testing.T) {
	// blockedHost is matched by one of the rules below.
	const blockedHost = "NULL.example.org."

	rules := "||nxdomain.example.org^\n||NULL.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"

	filterConf := &filtering.Config{
		ProtectionEnabled: true,
		BlockingMode:      filtering.BlockingModeCustomIP,
		BlockingIPv4:      netip.Addr{},
		BlockingIPv6:      netip.Addr{},
	}
	f, err := filtering.New(filterConf, []filtering.Filter{{
		ID:   0,
		Data: []byte(rules),
	}})
	require.NoError(t, err)

	s, err := NewServer(DNSCreateParams{
		DHCPServer: &testDHCP{
			OnEnabled:  func() (ok bool) { return false },
			OnHostByIP: func(_ netip.Addr) (host string) { panic("not implemented") },
			OnIPByHost: func(_ string) (ip netip.Addr) { panic("not implemented") },
		},
		DNSFilter:   f,
		PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		Logger:      slogutil.NewDiscardLogger(),
	})
	require.NoError(t, err)

	conf := &ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamDNS:  []string{"8.8.8.8:53", "8.8.4.4:53"},
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	}

	// Invalid BlockingIPv4.
	assert.Error(t, s.Prepare(conf))

	// Set valid blocking addresses and prepare again.
	blockingIPv4 := netip.AddrFrom4([4]byte{0, 0, 0, 1})
	blockingIPv6 := netip.MustParseAddr("::1")
	s.dnsFilter.SetBlockingMode(filtering.BlockingModeCustomIP, blockingIPv4, blockingIPv6)

	require.NoError(t, s.Prepare(conf))

	f.SetEnabled(true)

	startDeferStop(t, s)

	srvAddr := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	resp, err := dns.Exchange(createTestMessageWithType(blockedHost, dns.TypeA), srvAddr)
	require.NoError(t, err)
	require.Len(t, resp.Answer, 1)

	aRec, isA := resp.Answer[0].(*dns.A)
	require.True(t, isA)

	assert.True(t, net.IP{0, 0, 0, 1}.Equal(aRec.A))

	resp, err = dns.Exchange(createTestMessageWithType(blockedHost, dns.TypeAAAA), srvAddr)
	require.NoError(t, err)
	require.Len(t, resp.Answer, 1)

	aaaaRec, isAAAA := resp.Answer[0].(*dns.AAAA)
	require.True(t, isAAAA)

	assert.Equal(t, "::1", aaaaRec.AAAA.String())
}
// TestBlockedByHosts checks that a request for a host matched by a
// hosts-style rule is answered with a single A record containing the address
// from that rule, 127.0.0.1.
func TestBlockedByHosts(t *testing.T) {
	forwardConf := ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	}

	s := createTestServer(t, &filtering.Config{
		ProtectionEnabled: true,
		BlockingMode:      filtering.BlockingModeDefault,
	}, forwardConf)
	startDeferStop(t, s)

	addr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Hosts blocking.
	req := createTestMessage("host.example.org.")

	reply, err := dns.Exchange(req, addr.String())
	require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
	require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))

	a, ok := reply.Answer[0].(*dns.A)
	require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])

	// The message names the address that is actually expected.
	assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 127.0.0.1: %v", addr, a.A)
}
func TestBlockedBySafeBrowsing(t *testing.T) {
const (
hostname = "wmconvirus.narod.ru"
cacheTime = 10 * time.Minute
cacheSize = 10000
)
sbChecker := hashprefix.New(&hashprefix.Config{
CacheTime: cacheTime,
CacheSize: cacheSize,
Upstream: aghtest.NewBlockUpstream(hostname, true),
})
ans4, _ := aghtest.HostToIPs(hostname)
filterConf := &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
SafeBrowsingBlockHost: ans4.String(),
}
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// SafeBrowsing blocking.
req := createTestMessage(hostname + ".")
reply, err := dns.Exchange(req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
assertResponse(t, reply, ans4)
}
// TestRewrite checks that the legacy rewrites are applied to responses both
// with protection enabled and with protection disabled.
func TestRewrite(t *testing.T) {
	c := &filtering.Config{
		BlockingMode: filtering.BlockingModeDefault,
		Rewrites: []*filtering.LegacyRewrite{{
			Domain: "test.com",
			Answer: "1.2.3.4",
			Type:   dns.TypeA,
		}, {
			Domain: "alias.test.com",
			Answer: "test.com",
			Type:   dns.TypeCNAME,
		}, {
			Domain: "my.alias.example.org",
			Answer: "example.org",
			Type:   dns.TypeCNAME,
		}},
	}
	f, err := filtering.New(c, nil)
	require.NoError(t, err)

	f.SetEnabled(true)

	dhcp := &testDHCP{
		OnEnabled:  func() (ok bool) { return false },
		OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
		OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
	}

	s, err := NewServer(DNSCreateParams{
		DHCPServer:  dhcp,
		DNSFilter:   f,
		PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		Logger:      slogutil.NewDiscardLogger(),
	})
	require.NoError(t, err)

	// Stop right away if the server cannot be prepared, as the other tests
	// do, since starting an unprepared server is meaningless.
	err = s.Prepare(&ServerConfig{
		UDPListenAddrs: []*net.UDPAddr{{}},
		TCPListenAddrs: []*net.TCPAddr{{}},
		Config: Config{
			UpstreamDNS:  []string{"8.8.8.8:53"},
			UpstreamMode: UpstreamModeLoadBalance,
			EDNSClientSubnet: &EDNSClientSubnet{
				Enabled: false,
			},
		},
		ServePlainDNS: true,
	})
	require.NoError(t, err)

	ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
		return cmp.Or(
			aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"),
			new(dns.Msg).SetRcode(req, dns.RcodeNameError),
		), nil
	})
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
	startDeferStop(t, s)

	addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	// subTestFunc performs the same set of checks regardless of the current
	// protection setting.  Unexpected record types fail the test instead of
	// causing a panic.
	subTestFunc := func(t *testing.T) {
		// A rewrite to an IPv4 address answers A requests.
		req := createTestMessageWithType("test.com.", dns.TypeA)
		reply, eerr := dns.Exchange(req, addr)
		require.NoError(t, eerr)
		require.Len(t, reply.Answer, 1)

		a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
		assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))

		// The same rewrite produces an empty answer for AAAA requests.
		req = createTestMessageWithType("test.com.", dns.TypeAAAA)
		reply, eerr = dns.Exchange(req, addr)
		require.NoError(t, eerr)

		assert.Empty(t, reply.Answer)

		// A CNAME rewrite to a rewritten host yields both records.
		req = createTestMessageWithType("alias.test.com.", dns.TypeA)
		reply, eerr = dns.Exchange(req, addr)
		require.NoError(t, eerr)
		require.Len(t, reply.Answer, 2)

		cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
		assert.Equal(t, "test.com.", cname.Target)

		a = testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
		assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))

		// A CNAME rewrite to a host that is not rewritten is followed by a
		// record of type A.
		req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
		reply, eerr = dns.Exchange(req, addr)
		require.NoError(t, eerr)

		// The original question is restored.
		require.Len(t, reply.Question, 1)

		assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)

		require.Len(t, reply.Answer, 2)

		cname = testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
		assert.Equal(t, "example.org.", cname.Target)
		assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
	}

	for _, protect := range []bool{true, false} {
		val := protect
		conf := s.getDNSConfig()
		conf.ProtectionEnabled = &val
		s.setConfig(conf)

		t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
	}
}
// publicKey returns a pointer to the public part of priv if priv is an
// *rsa.PrivateKey or an *ecdsa.PrivateKey.  For any other value, including
// nil, it returns nil.
func publicKey(priv any) any {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &rsaKey.PublicKey
	}

	if ecdsaKey, ok := priv.(*ecdsa.PrivateKey); ok {
		return &ecdsaKey.PublicKey
	}

	return nil
}
// testDHCP is a mock implementation of the [DHCP] interface.
type testDHCP struct {
OnHostByIP func(ip netip.Addr) (host string)
OnIPByHost func(host string) (ip netip.Addr)
OnEnabled func() (ok bool)
}
// type check
var _ DHCP = (*testDHCP)(nil)
// HostByIP implements the [DHCP] interface for *testDHCP.  It returns the
// result of d.OnHostByIP.
func (d *testDHCP) HostByIP(ip netip.Addr) (host string) {
	host = d.OnHostByIP(ip)

	return host
}
// IPByHost implements the [DHCP] interface for *testDHCP.  It returns the
// result of d.OnIPByHost.
func (d *testDHCP) IPByHost(host string) (ip netip.Addr) {
	ip = d.OnIPByHost(host)

	return ip
}
// Enabled implements the [DHCP] interface for *testDHCP.  It returns the
// result of d.OnEnabled.
func (d *testDHCP) Enabled() (ok bool) {
	ok = d.OnEnabled()

	return ok
}
func TestPTRResponseFromDHCPLeases(t *testing.T) {
const localDomain = "lan"
flt, err := filtering.New(&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
DNSFilter: flt,
DHCPServer: &testDHCP{
OnEnabled: func() (ok bool) { return true },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
OnHostByIP: func(ip netip.Addr) (host string) {
return "myhost"
},
},
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
LocalDomain: localDomain,
})
require.NoError(t, err)
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(s.Close)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String())
require.NoErrorf(t, err, "%s", addr)
require.Len(t, resp.Answer, 1)
ans := resp.Answer[0]
assert.Equal(t, dns.TypePTR, ans.Header().Rrtype)
assert.Equal(t, "34.12.168.192.in-addr.arpa.", ans.Header().Name)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, ans)
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
}
// TestPTRResponseFromHosts checks that a PTR request for an address listed in
// the hosts file is answered with the hostname from that file, both when the
// protection is enabled and when it is disabled.
func TestPTRResponseFromHosts(t *testing.T) {
	// Prepare test hosts file.
	const hostsFilename = "hosts"

	testFS := fstest.MapFS{
		hostsFilename: &fstest.MapFile{Data: []byte(`
127.0.0.1 host # comment
::1 localhost#comment
`)},
	}

	// The DHCP mock is disabled and knows no hosts, so the answers can only
	// come from the hosts file.
	dhcp := &testDHCP{
		OnEnabled:  func() (ok bool) { return false },
		OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
		OnHostByIP: func(ip netip.Addr) (host string) { return "" },
	}

	// eventsCalledCounter counts the calls to OnEvents, which is expected to
	// be called exactly once.  It is a typed atomic so that it cannot be
	// accessed non-atomically by mistake.
	var eventsCalledCounter atomic.Uint32
	hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
		OnStart: func() (_ error) { panic("not implemented") },
		OnEvents: func() (e <-chan struct{}) {
			assert.Equal(t, uint32(1), eventsCalledCounter.Add(1))

			return nil
		},
		OnAdd: func(name string) (err error) {
			assert.Equal(t, hostsFilename, name)

			return nil
		},
		OnClose: func() (err error) { panic("not implemented") },
	}, hostsFilename)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.Equal(t, uint32(1), eventsCalledCounter.Load())
	})

	flt, err := filtering.New(&filtering.Config{
		BlockingMode: filtering.BlockingModeDefault,
		EtcHosts:     hc,
	}, nil)
	require.NoError(t, err)

	flt.SetEnabled(true)

	var s *Server
	s, err = NewServer(DNSCreateParams{
		DHCPServer:  dhcp,
		DNSFilter:   flt,
		PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		Logger:      slogutil.NewDiscardLogger(),
	})
	require.NoError(t, err)

	s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
	s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
	s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
	s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
	s.conf.Config.UpstreamMode = UpstreamModeLoadBalance

	err = s.Prepare(&s.conf)
	require.NoError(t, err)

	err = s.Start()
	require.NoError(t, err)
	t.Cleanup(s.Close)

	// subTestFunc sends the PTR request for 127.0.0.1 and checks that the
	// only answer points to the hostname from the hosts file.
	subTestFunc := func(t *testing.T) {
		addr := s.dnsProxy.Addr(proxy.ProtoUDP)
		req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)

		resp, eerr := dns.Exchange(req, addr.String())
		require.NoError(t, eerr)
		require.Len(t, resp.Answer, 1)

		ans := resp.Answer[0]
		assert.Equal(t, dns.TypePTR, ans.Header().Rrtype)
		assert.Equal(t, "1.0.0.127.in-addr.arpa.", ans.Header().Name)

		ptr := testutil.RequireTypeAssert[*dns.PTR](t, ans)
		assert.Equal(t, "host.", ptr.Ptr)
	}

	for _, protect := range []bool{true, false} {
		// Copy the value, since its address is stored in the configuration.
		val := protect

		conf := s.getDNSConfig()
		conf.ProtectionEnabled = &val
		s.setConfig(conf)

		t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
	}
}
// TestNewServer checks the error returned by [NewServer] for several values
// of the local domain in the creation parameters.
func TestNewServer(t *testing.T) {
	// TODO(a.garipov): Consider moving away from the text-based error
	// checks and onto a more structured approach.

	// Only the local domain differs between the cases, so the rest of the
	// parameters is constructed inside the subtest.
	cases := []struct {
		name        string
		localDomain string
		wantErrMsg  string
	}{{
		name:        "success",
		localDomain: "",
		wantErrMsg:  "",
	}, {
		name:        "success_local_tld",
		localDomain: "mynet",
		wantErrMsg:  "",
	}, {
		name:        "success_local_domain",
		localDomain: "my.local.net",
		wantErrMsg:  "",
	}, {
		name:        "bad_local_domain",
		localDomain: "!!!",
		wantErrMsg: `local domain: bad domain name "!!!": ` +
			`bad top-level domain name label "!!!": ` +
			`bad top-level domain name label rune '!'`,
	}}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			params := DNSCreateParams{
				Logger:      slogutil.NewDiscardLogger(),
				LocalDomain: c.localDomain,
			}

			_, newErr := NewServer(params)
			testutil.AssertErrorMsg(t, c.wantErrMsg, newErr)
		})
	}
}
// doubleTTL is a helper function that, if the first answer record of msg is a
// PTR record, appends to the answer section a copy of that record with its
// TTL doubled.  msg is modified in place and returned.  A nil message, a
// message without answers, and a message whose first answer is not a PTR
// record are returned unchanged.
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
	if msg == nil || len(msg.Answer) == 0 {
		return msg
	}

	if first, isPTR := msg.Answer[0].(*dns.PTR); isPTR {
		// Copy the record itself, so that the TTL of the original answer
		// stays intact.
		doubled := *first
		doubled.Hdr.Ttl *= 2

		msg.Answer = append(msg.Answer, &doubled)
	}

	return msg
}
// TestServer_Exchange checks the hostname, the TTL, and the error returned by
// [Server.Exchange] for a set of addresses.  The server under test is
// configured with one upstream shared by all cases and, in the table-driven
// part, with a local PTR resolver that uses the handler of the case.
func TestServer_Exchange(t *testing.T) {
	const (
		onesHost        = "one.one.one.one"
		twosHost        = "two.two.two.two"
		localDomainHost = "local.domain"

		// defaultTTL is the TTL expected for the ordinary successful cases.
		//
		// NOTE(review): presumably this matches the TTL that
		// aghtest.MatchedResponse puts into its answers; confirm.
		defaultTTL = time.Second * 60
	)

	var (
		onesIP  = netip.MustParseAddr("1.1.1.1")
		twosIP  = netip.MustParseAddr("2.2.2.2")
		localIP = netip.MustParseAddr("192.168.1.1")

		// pt is used instead of t for the checks inside the upstream
		// handlers.
		//
		// NOTE(review): presumably because the handlers are not run in the
		// test goroutine; confirm.
		pt = testutil.PanicT{}
	)

	onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
	require.NoError(t, err)

	twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
	require.NoError(t, err)

	// extUpsHdlr is the handler of the upstream shared by all cases.  It
	// writes the first non-nil message of: the match for the reversed onesIP,
	// the match for the reversed twosIP passed through doubleTTL, and an
	// NXDOMAIN response.
	//
	// NOTE(review): this relies on aghtest.MatchedResponse returning nil for
	// a request that doesn't match; confirm.
	extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		resp := cmp.Or(
			aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
			doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
			new(dns.Msg).SetRcode(req, dns.RcodeNameError),
		)

		require.NoError(pt, w.WriteMsg(resp))
	})
	upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String()

	revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
	require.NoError(t, err)

	// locUpsHdlr writes the match for the reversed localIP, if any, and an
	// NXDOMAIN response otherwise.
	locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		resp := cmp.Or(
			aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
			new(dns.Msg).SetRcode(req, dns.RcodeNameError),
		)

		require.NoError(pt, w.WriteMsg(resp))
	})

	// errUpsHdlr responds to every request with SERVFAIL.
	errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure)))
	})

	// nonPtrHdlr responds to every request with a single TXT record instead
	// of a PTR one.
	nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		hash := sha256.Sum256([]byte("some-host"))
		resp := (&dns.Msg{
			Answer: []dns.RR{&dns.TXT{
				Hdr: dns.RR_Header{
					Name:   req.Question[0].Name,
					Rrtype: dns.TypeTXT,
					Class:  dns.ClassINET,
					Ttl:    60,
				},
				Txt: []string{hex.EncodeToString(hash[:])},
			}},
		}).SetReply(req)

		require.NoError(pt, w.WriteMsg(resp))
	})

	// refusingHdlr responds to every request with REFUSED.
	refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)))
	})

	// zeroTTLHdlr responds to every request with a single PTR record that
	// points to localDomainHost and has a TTL of zero.
	zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		resp := (&dns.Msg{
			Answer: []dns.RR{&dns.PTR{
				Hdr: dns.RR_Header{
					Name:   req.Question[0].Name,
					Rrtype: dns.TypePTR,
					Class:  dns.ClassINET,
					Ttl:    0,
				},
				Ptr: dns.Fqdn(localDomainHost),
			}},
		}).SetReply(req)

		require.NoError(pt, w.WriteMsg(resp))
	})

	// In each case, req is the address passed to Exchange, locUpstream is the
	// handler of the local PTR resolver, and want, wantTTL, and wantErr are
	// the expected results.
	testCases := []struct {
		req         netip.Addr
		wantErr     error
		locUpstream dns.Handler
		name        string
		want        string
		wantTTL     time.Duration
	}{{
		name:        "external_good",
		want:        onesHost,
		wantErr:     nil,
		locUpstream: nil,
		req:         onesIP,
		wantTTL:     defaultTTL,
	}, {
		name:        "local_good",
		want:        localDomainHost,
		wantErr:     nil,
		locUpstream: locUpsHdlr,
		req:         localIP,
		wantTTL:     defaultTTL,
	}, {
		name:        "upstream_error",
		want:        "",
		wantErr:     ErrRDNSFailed,
		locUpstream: errUpsHdlr,
		req:         localIP,
		wantTTL:     0,
	}, {
		name:        "empty_answer_error",
		want:        "",
		wantErr:     ErrRDNSNoData,
		locUpstream: locUpsHdlr,
		req:         netip.MustParseAddr("192.168.1.2"),
		wantTTL:     0,
	}, {
		name:        "invalid_answer",
		want:        "",
		wantErr:     ErrRDNSNoData,
		locUpstream: nonPtrHdlr,
		req:         localIP,
		wantTTL:     0,
	}, {
		name:        "refused",
		want:        "",
		wantErr:     ErrRDNSFailed,
		locUpstream: refusingHdlr,
		req:         localIP,
		wantTTL:     0,
	}, {
		// The upstream answers with two PTR records, the second of which has
		// the doubled TTL, and the doubled TTL is the expected one.
		name:        "longest_ttl",
		want:        twosHost,
		wantErr:     nil,
		locUpstream: nil,
		req:         twosIP,
		wantTTL:     defaultTTL * 2,
	}, {
		name:        "zero_ttl",
		want:        localDomainHost,
		wantErr:     nil,
		locUpstream: zeroTTLHdlr,
		req:         localIP,
		wantTTL:     0,
	}}

	for _, tc := range testCases {
		// A separate local upstream is started for every case, including
		// those with a nil handler, and it is tied to the parent test.
		localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String()

		t.Run(tc.name, func(t *testing.T) {
			srv := createTestServer(t, &filtering.Config{
				BlockingMode: filtering.BlockingModeDefault,
			}, ServerConfig{
				Config: Config{
					UpstreamDNS:      []string{upsAddr},
					UpstreamMode:     UpstreamModeLoadBalance,
					EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
				},
				LocalPTRResolvers: []string{localUpsAddr},
				UsePrivateRDNS:    true,
				ServePlainDNS:     true,
			})

			host, ttl, eerr := srv.Exchange(tc.req)

			require.ErrorIs(t, eerr, tc.wantErr)
			assert.Equal(t, tc.want, host)
			assert.Equal(t, tc.wantTTL, ttl)
		})
	}

	// With no local PTR resolvers and UsePrivateRDNS left unset, the exchange
	// for the local address is expected to return an empty hostname and no
	// error.
	t.Run("resolving_disabled", func(t *testing.T) {
		srv := createTestServer(t, &filtering.Config{
			BlockingMode: filtering.BlockingModeDefault,
		}, ServerConfig{
			Config: Config{
				UpstreamDNS:      []string{upsAddr},
				UpstreamMode:     UpstreamModeLoadBalance,
				EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
			},
			LocalPTRResolvers: []string{},
			ServePlainDNS:     true,
		})

		host, _, eerr := srv.Exchange(localIP)

		require.NoError(t, eerr)
		assert.Empty(t, host)
	})
}