AdGuardHome/internal/dnsforward/dnsforward_test.go
Ainar Garipov fa33568fab Pull request: all: imp errcheck in tests
Merge in DNS/adguard-home from imp-errcheck to master

Squashed commit of the following:

commit ed046b8ef59a092a27c623cd14b3fc2ef305fc3d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 4 15:00:04 2021 +0300

    stats: imp cleanup more

commit e53a9240d3e3eec2417c768b98c267a8cd54d992
Merge: da734317 676f8c76
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 4 14:59:40 2021 +0300

    Merge branch 'master' into imp-errcheck

commit da734317035543b52e5a9030813084bdc92ba90a
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 4 14:37:26 2021 +0300

    stats: imp cleanup

commit 8b4ad150129111a09be6fa2944a21bd06ab8e5a1
Merge: 385c8a6c 5081ead0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 4 14:34:26 2021 +0300

    Merge branch 'master' into imp-errcheck

commit 385c8a6c91e3bf07a457da370c8cc77820b91600
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jan 29 20:41:57 2021 +0300

    all: imp errcheck in tests
2021-02-04 15:12:34 +03:00

1169 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package dnsforward
import (
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"os"
"sort"
"sync"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
// TestMain is the entry point of the package tests.  It delegates the whole
// run to testutil.DiscardLogOutput.
//
// NOTE(review): presumably that helper silences log output and then runs the
// tests through m — confirm against the testutil package.
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
const (
// tlsServerName is the DNS name put into the generated test certificate
// and used as the ServerName of the TLS client in the tests.
tlsServerName = "testdns.adguard.com"
// testMessagesCount is the number of messages sent by the
// sendTestMessages and sendTestMessagesAsync helpers.
testMessagesCount = 10
)
// startDeferStop starts s and registers a cleanup on t that stops it again.
// Both a start failure and a stop failure are reported as test failures.
func startDeferStop(t *testing.T, s *Server) {
	t.Helper()

	startErr := s.Start()
	assert.Nilf(t, startErr, "failed to start server: %s", startErr)

	stop := func() {
		stopErr := s.Stop()
		assert.Nilf(t, stopErr, "dns server failed to stop: %s", stopErr)
	}
	t.Cleanup(stop)
}
// TestServer checks that the server answers an A query for the test host over
// both UDP and TCP, using a mocked upstream that knows a single IPv4 record.
func TestServer(t *testing.T) {
	s := createTestServer(t)
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{
			ipv4: map[string][]net.IP{
				"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
			},
		},
	}
	startDeferStop(t, s)

	testCases := []struct {
		name  string
		proto string
	}{{
		name:  "message_over_udp",
		proto: proxy.ProtoUDP,
	}, {
		name:  "message_over_tcp",
		proto: proxy.ProtoTCP,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			addr := s.dnsProxy.Addr(tc.proto)
			client := dns.Client{Net: tc.proto}

			reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
			// The message is plain ASCII; it used to start with a Cyrillic
			// homoglyph of the letter "c".
			if !assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) {
				// There is no reply to inspect after a failed exchange.
				return
			}

			assertGoogleAResponse(t, reply)
		})
	}
}
// TestServerWithProtectionDisabled checks that the server still resolves an A
// query over UDP when ProtectionEnabled is switched off in its configuration.
func TestServerWithProtectionDisabled(t *testing.T) {
	s := createTestServer(t)
	s.conf.ProtectionEnabled = false
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{
			ipv4: map[string][]net.IP{
				"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
			},
		},
	}
	startDeferStop(t, s)

	// Message over UDP.
	req := createGoogleATestMessage()
	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
	client := dns.Client{Net: proxy.ProtoUDP}

	reply, _, err := client.Exchange(req, addr.String())
	// The message is plain ASCII; it used to start with a Cyrillic homoglyph
	// of the letter "c".
	if !assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) {
		// There is no reply to inspect after a failed exchange.
		return
	}

	assertGoogleAResponse(t, reply)
}
// createTestTLS returns a test server whose TLS configuration is tlsConf
// completed with a freshly generated certificate chain and private key.  The
// server has Prepare called on it; the PEM-encoded certificate is returned so
// that callers can trust it on the client side.
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
	t.Helper()

	_, chain, key := createServerTLSConfig(t)

	s = createTestServer(t)

	tlsConf.CertificateChainData = chain
	tlsConf.PrivateKeyData = key
	s.conf.TLSConfig = tlsConf

	prepErr := s.Prepare(nil)
	assert.Nilf(t, prepErr, "failed to prepare server: %s", prepErr)

	return s, chain
}
// TestDoTServer checks that the server answers queries sent over a
// DNS-over-TLS connection that trusts the generated self-signed certificate.
func TestDoTServer(t *testing.T) {
	s, certPem := createTestTLS(t, TLSConfig{
		TLSListenAddr: &net.TCPAddr{Port: 0},
	})
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{
			ipv4: map[string][]net.IP{
				"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
			},
		},
	}
	startDeferStop(t, s)

	// Add our self-signed generated config to roots.
	roots := x509.NewCertPool()
	ok := roots.AppendCertsFromPEM(certPem)
	assert.True(t, ok, "cannot add the generated certificate to the root pool")

	tlsConfig := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
		MinVersion: tls.VersionTLS12,
	}

	// Create a DNS-over-TLS client connection.
	addr := s.dnsProxy.Addr(proxy.ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	if !assert.Nilf(t, err, "cannot connect to the proxy: %s", err) {
		// Without a connection there is nothing to send messages over.
		return
	}
	t.Cleanup(func() {
		// Best-effort close: the outcome of the test does not depend on it,
		// the point is only not to leak the connection.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
// TestDoQServer checks that the server answers a query sent through a
// DNS-over-QUIC upstream pointed at its own QUIC listener.
func TestDoQServer(t *testing.T) {
	s, _ := createTestTLS(t, TLSConfig{
		QUICListenAddr: &net.UDPAddr{Port: 0},
	})

	mock := &testUpstream{
		ipv4: map[string][]net.IP{
			"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
		},
	}
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{mock}
	startDeferStop(t, s)

	// Create a DNS-over-QUIC upstream.  Certificate verification is skipped
	// because the server certificate is self-signed.
	quicAddr := s.dnsProxy.Addr(proxy.ProtoQUIC)
	target := fmt.Sprintf("%s://%s", proxy.ProtoQUIC, quicAddr)
	u, err := upstream.AddressToUpstream(target, upstream.Options{InsecureSkipVerify: true})
	assert.Nil(t, err)

	// Send the test message.
	res, err := u.Exchange(createGoogleATestMessage())
	assert.Nil(t, err)
	assertGoogleAResponse(t, res)
}
func TestServerRace(t *testing.T) {
t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks")
s := createTestServer(t)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&testUpstream{
ipv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
},
}
startDeferStop(t, s)
// Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
assert.Nilf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn)
}
// testResolver is a Resolver for tests.  It never touches the network: the
// addresses it reports are cut out of the SHA-256 digest of the host name, so
// the same host always yields the same addresses.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
type testResolver struct{}

// LookupIPAddr implements Resolver interface for *testResolver.  It returns
// two addresses in the zone "somezone": a 4-byte one built from digest bytes
// 0..3 and a 16-byte one built from digest bytes 4..19.  The error is always
// nil.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
	const zone = "somezone"

	digest := sha256.Sum256([]byte(host))
	short, long := net.IP(digest[:4]), net.IP(digest[4:20])

	ips = append(ips,
		net.IPAddr{IP: short, Zone: zone},
		net.IPAddr{IP: long, Zone: zone},
	)

	return ips, nil
}

// LookupHost implements Resolver interface for *testResolver.  It returns the
// textual forms of the same two digest-derived addresses that LookupIPAddr
// reports, without the zone.  The error is always nil.
//
//lint:ignore U1000 TODO(e.burkov): move into aghtest package.
func (r *testResolver) LookupHost(host string) (addrs []string, err error) {
	digest := sha256.Sum256([]byte(host))

	return []string{
		net.IP(digest[:4]).String(),
		net.IP(digest[4:20]).String(),
	}, nil
}
// TestSafeSearch queries a list of search engine hosts and expects every one
// of them to resolve to testUpstreamIP.  It is currently skipped, see the TODO
// below.
func TestSafeSearch(t *testing.T) {
	t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom resolver from aghtest")

	testUpstreamIP := net.IP{213, 180, 193, 56}
	hosts := []string{
		"yandex.com.",
		"yandex.by.",
		"yandex.kz.",
		"yandex.ru.",
		"www.google.com.",
		"www.google.com.af.",
		"www.google.be.",
		"www.google.by.",
	}

	for i := range hosts {
		host := hosts[i]
		t.Run("safe_search_"+host, func(t *testing.T) {
			// Every host gets a server of its own.
			s := createTestServer(t)
			startDeferStop(t, s)

			udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)
			client := &dns.Client{Net: proxy.ProtoUDP}

			exchangeAndAssertResponse(t, client, udpAddr, host, testUpstreamIP)
		})
	}
}
// TestInvalidRequest sends a query without a question section over UDP and
// asserts that the exchange completes without an error within the timeout.
//
// NOTE(review): the failure message speaks of "got a response", which reads
// like the opposite of what assert.Nil checks — confirm which one is intended.
func TestInvalidRequest(t *testing.T) {
	s := createTestServer(t)
	startDeferStop(t, s)

	target := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	// A header-only message: the question section is left out on purpose.
	noQuestion := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
	}
	udpClient := &dns.Client{
		Net:     proxy.ProtoUDP,
		Timeout: 500 * time.Millisecond,
	}

	// Send a DNS request without question.
	_, _, err := udpClient.Exchange(noQuestion, target)
	assert.Nil(t, err, "got a response to an invalid query")
}
// TestBlockedRequest checks that a host matched by a blocking rule is answered
// with a successful response carrying an unspecified IPv4 address.
func TestBlockedRequest(t *testing.T) {
	s := createTestServer(t)
	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Default blocking.
	reply, err := dns.Exchange(createTestMessage("nxdomain.example.org."), udpAddr.String())
	assert.Nilf(t, err, "couldn't talk to server %s: %s", udpAddr, err)
	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)

	first := reply.Answer[0].(*dns.A)
	assert.True(t, first.A.IsUnspecified())
}
// TestServerCustomClientUpstream checks that the upstream configuration
// returned by GetCustomUpstreamByClient is the one used to answer the query.
func TestServerCustomClientUpstream(t *testing.T) {
	wantIP := net.IP{192, 168, 0, 1}

	s := createTestServer(t)
	s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
		custom := &testUpstream{
			ipv4: map[string][]net.IP{
				"host.": {{192, 168, 0, 1}},
			},
		}

		return &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{custom}}
	}
	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Send test request.
	reply, err := dns.Exchange(createTestMessage("host."), udpAddr.String())
	assert.Nil(t, err)
	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
	assert.NotEmpty(t, reply.Answer)
	assert.Equal(t, wantIP, reply.Answer[0].(*dns.A).A)
}
// testUpstream is a mock of a real upstream.  Fill its fields with the
// records the mocked upstream should know about; names missing from the maps
// are treated as unknown by Exchange.
//
// TODO(e.burkov): move into aghtest package.
type testUpstream struct {
// cn is a map of hostname to canonical name.
cn map[string]string
// ipv4 is a map of hostname to IPv4.
ipv4 map[string][]net.IP
// ipv6 is a map of hostname to IPv6.
ipv6 map[string][]net.IP
}
// Exchange implements upstream.Upstream interface for *testUpstream.
//
// The reply to m is built from the maps of u, looked up by the name of the
// first question:
//
//   - a CNAME record is added when the name is present in u.cn;
//   - for A questions, one A record is added per address in u.ipv4;
//   - for AAAA questions, one AAAA record is added per address in u.ipv6.
//
// When the answer section ends up empty, the response code is NXDOMAIN unless
// the name is present in the address map for the question type, in which case
// it is a success.  A message without questions is answered with a format
// error instead of being indexed out of range.  The returned error is always
// nil.
func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
	resp := &dns.Msg{}
	resp.SetReply(m)

	if len(m.Question) == 0 {
		resp.SetRcode(m, dns.RcodeFormatError)

		return resp, nil
	}

	q := m.Question[0]
	name := q.Name

	if cname, ok := u.cn[name]; ok {
		resp.Answer = append(resp.Answer, &dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   name,
				Rrtype: dns.TypeCNAME,
			},
			Target: cname,
		})
	}

	// hasRec tells whether the name is known for the question type, even if
	// no address ends up in the answer.
	hasRec := false
	var ips []net.IP

	switch q.Qtype {
	case dns.TypeA:
		ips, hasRec = u.ipv4[name]
		for _, ip := range ips {
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   name,
					Rrtype: dns.TypeA,
				},
				A: ip,
			})
		}
	case dns.TypeAAAA:
		ips, hasRec = u.ipv6[name]
		for _, ip := range ips {
			// IPv6 addresses must go into AAAA records: an A record cannot
			// carry them.
			resp.Answer = append(resp.Answer, &dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   name,
					Rrtype: dns.TypeAAAA,
				},
				AAAA: ip,
			})
		}
	}

	if len(resp.Answer) == 0 {
		if hasRec {
			// Set no error RCode if there are some records for
			// given Qname but we didn't apply them.
			resp.SetRcode(m, dns.RcodeSuccess)

			return resp, nil
		}

		// Set NXDomain RCode otherwise.
		resp.SetRcode(m, dns.RcodeNameError)
	}

	return resp, nil
}
// Address implements upstream.Upstream interface for *testUpstream.  It
// always returns the fixed string "test".
func (u *testUpstream) Address() string {
return "test"
}
// startWithUpstream prepares s with its current configuration, replaces the
// upstream configuration of the underlying proxy with the single upstream u
// and starts the proxy.  The server lock is held for the whole call.  It
// returns the error of Prepare or of the proxy start, whichever fails first.
func (s *Server) startWithUpstream(u upstream.Upstream) error {
	s.Lock()
	defer s.Unlock()

	if prepErr := s.Prepare(nil); prepErr != nil {
		return prepErr
	}

	only := []upstream.Upstream{u}
	s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{Upstreams: only}

	return s.dnsProxy.Start()
}
// testCNAMEs maps host names to their canonical names.  It is used as the cn
// map of testUpstream in the CNAME blocking tests; both names point at
// "null.example.org.".
var testCNAMEs = map[string]string{
"badhost.": "null.example.org.",
"whitelist.example.org.": "null.example.org.",
}
// testIPv4 maps host names to their IPv4 addresses.  It is used as the ipv4
// map of testUpstream in the CNAME blocking tests.
var testIPv4 = map[string][]net.IP{
"null.example.org.": {{1, 2, 3, 4}},
"example.org.": {{127, 0, 0, 255}},
}
// TestBlockCNAMEProtectionEnabled checks that a host whose canonical name is
// matched by a blocking rule is still resolved successfully when protection
// is switched off in the server configuration.
//
// The server is started through startWithUpstream and is stopped again in a
// cleanup, so that its listeners do not outlive the test.
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
	s := createTestServer(t)
	testUpstm := &testUpstream{
		cn:   testCNAMEs,
		ipv4: testIPv4,
		ipv6: nil,
	}
	s.conf.ProtectionEnabled = false

	err := s.startWithUpstream(testUpstm)
	if !assert.Nil(t, err) {
		// Nothing was started, so there is nothing to query or to stop.
		return
	}
	t.Cleanup(func() {
		stopErr := s.Stop()
		assert.Nilf(t, stopErr, "dns server failed to stop: %s", stopErr)
	})

	addr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// 'badhost' has a canonical name 'null.example.org' which is blocked by
	// filters: but protection is disabled so response is _not_ blocked.
	req := createTestMessage("badhost.")
	reply, err := dns.Exchange(req, addr.String())
	if !assert.Nil(t, err) {
		// There is no reply to inspect after a failed exchange.
		return
	}

	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
}
// TestBlockCNAME checks blocking by canonical name and by answer address.
// Every host must be answered successfully; hosts that are expected to be
// blocked must additionally get an unspecified IPv4 address as the first
// answer.
func TestBlockCNAME(t *testing.T) {
	s := createTestServer(t)
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{cn: testCNAMEs, ipv4: testIPv4},
	}
	startDeferStop(t, s)

	target := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	type blockCase struct {
		host string
		want bool
	}
	cases := []blockCase{
		// 'badhost' has a canonical name 'null.example.org' which is
		// blocked by filters: response is blocked.
		{host: "badhost.", want: true},
		// 'whitelist.example.org' has a canonical name
		// 'null.example.org' which is blocked by filters but
		// 'whitelist.example.org' is in a whitelist: response isn't
		// blocked.
		{host: "whitelist.example.org.", want: false},
		// 'example.org' resolves to 127.0.0.255 which is blocked by
		// filters: response is blocked.
		{host: "example.org.", want: true},
	}

	for i := range cases {
		bc := cases[i]
		t.Run("block_cname_"+bc.host, func(t *testing.T) {
			reply, err := dns.Exchange(createTestMessage(bc.host), target)
			assert.Nil(t, err)
			assert.Equal(t, dns.RcodeSuccess, reply.Rcode)

			if !bc.want {
				return
			}

			first := reply.Answer[0].(*dns.A)
			assert.True(t, first.A.IsUnspecified())
		})
	}
}
// TestClientRulesForCNAMEMatching checks that a host whose canonical name is
// matched by a blocking rule is resolved successfully when the FilterHandler
// switches filtering off in the request settings.
func TestClientRulesForCNAMEMatching(t *testing.T) {
	s := createTestServer(t)
	s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
		settings.FilteringEnabled = false
	}
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{cn: testCNAMEs, ipv4: testIPv4},
	}
	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// 'badhost' has a canonical name 'null.example.org' which is blocked by
	// filters: response is blocked.
	question := dns.Question{
		Name:   "badhost.",
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}
	req := &dns.Msg{
		MsgHdr:   dns.MsgHdr{Id: dns.Id()},
		Question: []dns.Question{question},
	}

	// However, in our case it should not be blocked as filtering is
	// disabled on the client level.
	reply, err := dns.Exchange(req, udpAddr.String())
	assert.Nil(t, err)
	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
}
// TestNullBlockedRequest checks that in the "null_ip" blocking mode a blocked
// host is answered with exactly one A record holding an unspecified address.
func TestNullBlockedRequest(t *testing.T) {
	s := createTestServer(t)
	s.conf.FilteringConfig.BlockingMode = "null_ip"
	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Nil filter blocking.  The helper builds the same recursive A/IN query
	// as a hand-written message would.
	query := createTestMessage("null.example.org.")

	reply, err := dns.Exchange(query, udpAddr.String())
	assert.Nilf(t, err, "couldn't talk to server %s: %s", udpAddr, err)
	assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", udpAddr, len(reply.Answer))

	first := reply.Answer[0]
	rec, isA := first.(*dns.A)
	assert.Truef(t, isA, "dns server %s returned wrong answer type instead of A: %v", udpAddr, first)
	assert.Truef(t, rec.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", udpAddr, rec.A)
}
// TestBlockedCustomIP checks the "custom_ip" blocking mode: Prepare must fail
// while no blocking IPv4 address is configured, and once both blocking
// addresses are set, a blocked host must be answered with them for A and AAAA
// questions respectively.
func TestBlockedCustomIP(t *testing.T) {
	const rules = "||nxdomain.example.org^\n" +
		"||null.example.org^\n" +
		"127.0.0.1 host.example.org\n" +
		"@@||whitelist.example.org^\n" +
		"||127.0.0.255\n"

	filter := dnsfilter.New(&dnsfilter.Config{}, []dnsfilter.Filter{{
		ID:   0,
		Data: []byte(rules),
	}})
	s := NewServer(DNSCreateParams{DNSFilter: filter})

	conf := ServerConfig{
		UDPListenAddr: &net.UDPAddr{Port: 0},
		TCPListenAddr: &net.TCPAddr{Port: 0},
		FilteringConfig: FilteringConfig{
			ProtectionEnabled: true,
			BlockingMode:      "custom_ip",
			BlockingIPv4:      nil,
			UpstreamDNS:       []string{"8.8.8.8:53", "8.8.4.4:53"},
		},
	}

	// Invalid BlockingIPv4.
	assert.NotNil(t, s.Prepare(&conf))

	conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
	conf.BlockingIPv6 = net.ParseIP("::1")
	assert.Nil(t, s.Prepare(&conf))

	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// The A question must be answered with the custom IPv4 address.
	reply, err := dns.Exchange(createTestMessageWithType("null.example.org.", dns.TypeA), udpAddr.String())
	assert.Nil(t, err)
	assert.Len(t, reply.Answer, 1)

	v4, isA := reply.Answer[0].(*dns.A)
	assert.True(t, isA)
	assert.True(t, net.IP{0, 0, 0, 1}.Equal(v4.A))

	// The AAAA question must be answered with the custom IPv6 address.
	reply, err = dns.Exchange(createTestMessageWithType("null.example.org.", dns.TypeAAAA), udpAddr.String())
	assert.Nil(t, err)
	assert.Len(t, reply.Answer, 1)

	v6, isAAAA := reply.Answer[0].(*dns.AAAA)
	assert.True(t, isAAAA)
	assert.Equal(t, "::1", v6.AAAA.String())
}
// TestBlockedByHosts checks that a host listed in a hosts-style rule
// ("127.0.0.1 host.example.org") is answered with exactly one A record
// holding the address from that rule.
func TestBlockedByHosts(t *testing.T) {
	s := createTestServer(t)
	startDeferStop(t, s)
	addr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// Hosts blocking.
	req := createTestMessage("host.example.org.")
	reply, err := dns.Exchange(req, addr.String())
	if !assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) {
		// There is no reply to inspect after a failed exchange.
		return
	}

	// Stop here on a wrong answer count so that the indexing below cannot
	// panic on an empty answer section.
	if !assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) {
		return
	}

	a, ok := reply.Answer[0].(*dns.A)
	if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) {
		// The message names the address that is actually expected.
		assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 127.0.0.1: %v", addr, a.A)
	}
}
// TestBlockedBySafeBrowsing expects a host flagged by safe browsing to be
// answered with one of the addresses testResolver reports for
// safeBrowsingBlockHost.  It is currently skipped, see the TODO below.
func TestBlockedBySafeBrowsing(t *testing.T) {
	t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom safeBrowsingUpstream")

	resolver := &testResolver{}
	ips, _ := resolver.LookupIPAddr(context.Background(), safeBrowsingBlockHost)
	blockAddrs, _ := resolver.LookupHost(safeBrowsingBlockHost)

	// isBlockAddr reports whether ip is one of the block host addresses.
	isBlockAddr := func(ip net.IP) bool {
		str := ip.String()
		for i := range blockAddrs {
			if blockAddrs[i] == str {
				return true
			}
		}

		return false
	}

	s := createTestServer(t)
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{
			ipv4: map[string][]net.IP{
				"wmconvirus.narod.ru.": {ips[0].IP},
			},
		},
	}
	startDeferStop(t, s)

	udpAddr := s.dnsProxy.Addr(proxy.ProtoUDP)

	// SafeBrowsing blocking.
	reply, err := dns.Exchange(createTestMessage("wmconvirus.narod.ru."), udpAddr.String())
	assert.Nilf(t, err, "couldn't talk to server %s: %s", udpAddr, err)
	assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", udpAddr, len(reply.Answer))

	rec, isA := reply.Answer[0].(*dns.A)
	if !assert.Truef(t, isA, "dns server %s returned wrong answer type instead of A: %v", udpAddr, reply.Answer[0]) {
		return
	}

	assert.Truef(t, isBlockAddr(rec.A), "dns server %s returned wrong answer: %v", udpAddr, rec.A)
}
// TestRewrite checks DNS rewrites: a plain A rewrite, the empty answer for a
// AAAA question on the same host, a CNAME rewrite that points at another
// rewritten host and a CNAME rewrite whose target is resolved through the
// upstream.
func TestRewrite(t *testing.T) {
	rewrites := []dnsfilter.RewriteEntry{
		{Domain: "test.com", Answer: "1.2.3.4", Type: dns.TypeA},
		{Domain: "alias.test.com", Answer: "test.com", Type: dns.TypeCNAME},
		{Domain: "my.alias.example.org", Answer: "example.org", Type: dns.TypeCNAME},
	}
	filter := dnsfilter.New(&dnsfilter.Config{Rewrites: rewrites}, nil)

	s := NewServer(DNSCreateParams{DNSFilter: filter})
	err := s.Prepare(&ServerConfig{
		UDPListenAddr: &net.UDPAddr{Port: 0},
		TCPListenAddr: &net.TCPAddr{Port: 0},
		FilteringConfig: FilteringConfig{
			ProtectionEnabled: true,
			UpstreamDNS:       []string{"8.8.8.8:53"},
		},
	})
	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
		&testUpstream{
			cn: map[string]string{
				"example.org": "somename",
			},
			ipv4: map[string][]net.IP{
				"example.org.": {{4, 3, 2, 1}},
			},
		},
	}
	assert.Nil(t, err)

	startDeferStop(t, s)

	target := s.dnsProxy.Addr(proxy.ProtoUDP).String()

	// exchange queries the server for host with the given question type and
	// asserts that the exchange itself succeeded.
	exchange := func(host string, qtype uint16) *dns.Msg {
		reply, exchErr := dns.Exchange(createTestMessageWithType(host, qtype), target)
		assert.Nil(t, exchErr)

		return reply
	}

	wantIP := net.IP{1, 2, 3, 4}

	// The A rewrite is answered directly.
	reply := exchange("test.com.", dns.TypeA)
	assert.Len(t, reply.Answer, 1)
	rec, isA := reply.Answer[0].(*dns.A)
	assert.True(t, isA)
	assert.True(t, wantIP.Equal(rec.A))

	// There is no AAAA rewrite for the same host.
	reply = exchange("test.com.", dns.TypeAAAA)
	assert.Empty(t, reply.Answer)

	// The CNAME rewrite is followed to the rewritten A record.
	reply = exchange("alias.test.com.", dns.TypeA)
	assert.Len(t, reply.Answer, 2)
	assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
	assert.True(t, wantIP.Equal(reply.Answer[1].(*dns.A).A))

	// The CNAME rewrite target is resolved, and the original question is
	// restored in the reply.
	reply = exchange("my.alias.example.org.", dns.TypeA)
	assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
	assert.Len(t, reply.Answer, 2)
	assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
	assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
}
// createTestServer returns a prepared, not yet started server for tests.  It
// listens on random UDP and TCP ports, has protection enabled and filters
// with a small fixed rule list that blocks a few example.org hosts,
// whitelists one and blocks answers pointing at 127.0.0.255.
func createTestServer(t *testing.T) *Server {
	const rules = `||nxdomain.example.org
||null.example.org^
127.0.0.1 host.example.org
@@||whitelist.example.org^
||127.0.0.255`

	filterConf := &dnsfilter.Config{
		SafeBrowsingEnabled:   true,
		SafeBrowsingCacheSize: 1000,
		SafeSearchEnabled:     true,
		SafeSearchCacheSize:   1000,
		ParentalCacheSize:     1000,
		CacheTime:             30,
	}
	filter := dnsfilter.New(filterConf, []dnsfilter.Filter{{
		ID:   0,
		Data: []byte(rules),
	}})

	srv := NewServer(DNSCreateParams{DNSFilter: filter})
	srv.conf = ServerConfig{
		UDPListenAddr: &net.UDPAddr{Port: 0},
		TCPListenAddr: &net.TCPAddr{Port: 0},
		FilteringConfig: FilteringConfig{
			ProtectionEnabled: true,
			UpstreamDNS:       []string{"8.8.8.8:53", "8.8.4.4:53"},
		},
		ConfigModified: func() {},
	}

	prepErr := srv.Prepare(nil)
	assert.Nil(t, prepErr)

	return srv
}
// createServerTLSConfig generates a self-signed CA certificate for
// tlsServerName with a fresh 2048-bit RSA key.  It returns a TLS
// configuration using that key pair together with the PEM encodings of the
// certificate and of the private key.  The certificate is valid for five
// years from now.
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	assert.Nilf(t, err, "cannot generate RSA key: %s", err)

	// The serial number is a random integer below 2^128.
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	assert.Nilf(t, err, "failed to generate serial number: %s", err)

	validFrom := time.Now()
	validUntil := validFrom.Add(5 * 365 * time.Hour * 24)

	tmpl := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"AdGuard Tests"},
		},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{tlsServerName},
	}

	// The template is its own parent, which makes the certificate
	// self-signed.
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, publicKey(key), key)
	assert.Nilf(t, err, "failed to create certificate: %s", err)

	certPem := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: der,
	})
	keyPem := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	pair, err := tls.X509KeyPair(certPem, keyPem)
	assert.Nilf(t, err, "failed to create certificate: %s", err)

	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{pair},
		ServerName:   tlsServerName,
		MinVersion:   tls.VersionTLS12,
	}

	return tlsConf, certPem, keyPem
}
// sendTestMessagesAsync sends messages in parallel to check for race issues.
// It starts testMessagesCount goroutines, each of which writes one query to
// conn, reads one reply and asserts on it, and returns once all of them are
// done.
//lint:ignore U1000 it's called from the function which is skipped for now.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
	var wg sync.WaitGroup
	wg.Add(testMessagesCount)

	for n := 0; n < testMessagesCount; n++ {
		// The message is built before the goroutine starts.
		query := createGoogleATestMessage()

		go func() {
			defer wg.Done()

			writeErr := conn.WriteMsg(query)
			assert.Nilf(t, writeErr, "cannot write message: %s", writeErr)

			reply, readErr := conn.ReadMsg()
			assert.Nilf(t, readErr, "cannot read response to message: %s", readErr)

			assertGoogleAResponse(t, reply)
		}()
	}

	wg.Wait()
}
// sendTestMessages writes testMessagesCount queries to conn one after
// another, reading and asserting on the reply to each before sending the next
// one.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	t.Helper()

	for n := 0; n < testMessagesCount; n++ {
		writeErr := conn.WriteMsg(createGoogleATestMessage())
		assert.Nilf(t, writeErr, "cannot write message #%d: %s", n, writeErr)

		reply, readErr := conn.ReadMsg()
		assert.Nilf(t, readErr, "cannot read response to message #%d: %s", n, readErr)

		assertGoogleAResponse(t, reply)
	}
}
// exchangeAndAssertResponse sends an A query for host to addr through client
// and asserts that the reply holds a single A record equal to ip.
func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host string, ip net.IP) {
	t.Helper()

	reply, _, exchErr := client.Exchange(createTestMessage(host), addr.String())
	assert.Nilf(t, exchErr, "couldn't talk to server %s: %s", addr, exchErr)

	assertResponse(t, reply, ip)
}
// createGoogleATestMessage returns an A query for the host name that the
// mocked upstreams in these tests resolve to 8.8.8.8.
func createGoogleATestMessage() *dns.Msg {
	const host = "google-public-dns-a.google.com."

	return createTestMessage(host)
}
// createTestMessage returns a recursive query with a random ID and a single
// A question of class IN for host.
func createTestMessage(host string) *dns.Msg {
	question := dns.Question{
		Name:   host,
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	msg := &dns.Msg{}
	msg.Id = dns.Id()
	msg.RecursionDesired = true
	msg.Question = []dns.Question{question}

	return msg
}
// createTestMessageWithType returns the same query as createTestMessage, with
// the type of its only question replaced by qtype.
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
	msg := createTestMessage(host)

	question := &msg.Question[0]
	question.Qtype = qtype

	return msg
}
// assertGoogleAResponse asserts that reply holds a single A record with the
// address 8.8.8.8.  It is marked as a test helper, like the other assertion
// helpers of this file, so that failures are attributed to the calling test
// line.
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
	t.Helper()

	assertResponse(t, reply, net.IP{8, 8, 8, 8})
}
// assertResponse asserts that reply holds exactly one answer and that this
// answer is an A record equal to ip.
//
// A nil reply or an answer section of the wrong length is reported as a
// failure and ends the check early, so that the helper never dereferences a
// missing reply or indexes an empty answer section.
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
	t.Helper()

	if !assert.NotNil(t, reply, "dns server returned no reply") {
		return
	}

	if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) {
		return
	}

	a, ok := reply.Answer[0].(*dns.A)
	if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) {
		assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
	}
}
// publicKey returns a pointer to the public part of priv when priv is an
// *rsa.PrivateKey or an *ecdsa.PrivateKey, and nil for any other value.
func publicKey(priv interface{}) interface{} {
	if rsaKey, isRSA := priv.(*rsa.PrivateKey); isRSA {
		return &rsaKey.PublicKey
	}

	if ecKey, isEC := priv.(*ecdsa.PrivateKey); isEC {
		return &ecKey.PublicKey
	}

	return nil
}
// TestValidateUpstream feeds single upstream lines to validateUpstream and
// checks both whether each line is accepted and, for accepted lines, whether
// it is reported as a default upstream.
func TestValidateUpstream(t *testing.T) {
	// sdnsStamp is the DNSCrypt stamp shared by the sdns:// cases.
	const sdnsStamp = "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"

	type upstreamCase struct {
		name     string
		upstream string
		valid    bool
		wantDef  bool
	}
	cases := []upstreamCase{
		// Lines that must be rejected.
		{name: "invalid", upstream: "1.2.3.4.5", valid: false},
		{name: "invalid", upstream: "123.3.7m", valid: false},
		{name: "invalid", upstream: "htttps://google.com/dns-query", valid: false},
		{name: "invalid", upstream: "[/host.com]tls://dns.adguard.com", valid: false},
		{name: "invalid", upstream: "[host.ru]#", valid: false},

		// Lines that are accepted as default upstreams.
		{name: "valid_default", upstream: "1.1.1.1", valid: true, wantDef: true},
		{name: "valid_default", upstream: "tls://1.1.1.1", valid: true, wantDef: true},
		{name: "valid_default", upstream: "https://dns.adguard.com/dns-query", valid: true, wantDef: true},
		{name: "valid_default", upstream: sdnsStamp, valid: true, wantDef: true},

		// Lines that are accepted as domain-specific upstreams.
		{name: "valid", upstream: "[/host.com/]1.1.1.1", valid: true, wantDef: false},
		{name: "valid", upstream: "[//]tls://1.1.1.1", valid: true, wantDef: false},
		{name: "valid", upstream: "[/www.host.com/]#", valid: true, wantDef: false},
		{name: "valid", upstream: "[/host.com/google.com/]8.8.8.8", valid: true, wantDef: false},
		{name: "valid", upstream: "[/host/]" + sdnsStamp, valid: true, wantDef: false},
	}

	for i := range cases {
		uc := cases[i]
		t.Run(uc.name, func(t *testing.T) {
			isDefault, err := validateUpstream(uc.upstream)
			assert.Equal(t, uc.valid, err == nil)

			if err != nil {
				// The default flag is only meaningful for accepted lines.
				return
			}

			assert.Equal(t, uc.wantDef, isDefault)
		})
	}
}
// TestValidateUpstreamsSet checks ValidateUpstreams on whole sets of upstream
// lines: an empty set, a set holding only a comment, a set without a default
// upstream, the same set with a default upstream added and finally with an
// invalid upstream added on top.
func TestValidateUpstreamsSet(t *testing.T) {
	// Empty upstreams array.
	var set []string
	assert.Nil(t, ValidateUpstreams(set), "empty upstreams array should be valid")

	// Comment in upstreams array.
	set = []string{"# comment"}
	assert.Nil(t, ValidateUpstreams(set), "comments should not be validated")

	// Set of valid upstreams. There is no default upstream specified.
	set = []string{
		"[/host.com/]1.1.1.1",
		"[//]tls://1.1.1.1",
		"[/www.host.com/]#",
		"[/host.com/google.com/]8.8.8.8",
		"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
	}
	assert.NotNil(t, ValidateUpstreams(set), "there is no default upstream")

	// Let's add default upstream.
	set = append(set, "8.8.8.8")
	validErr := ValidateUpstreams(set)
	assert.Nilf(t, validErr, "upstreams set is valid, but doesn't pass through validation cause: %s", validErr)

	// Let's add invalid upstream.
	set = append(set, "dhcp://fake.dns")
	assert.NotNil(t, ValidateUpstreams(set), "there is an invalid upstream in set, but it pass through validation")
}
// TestIPStringFromAddr checks that IPStringFromAddr returns the textual form
// of the IP of a UDP address, without port and zone, and an empty string for
// a nil address.
func TestIPStringFromAddr(t *testing.T) {
	addr := net.UDPAddr{
		IP:   net.ParseIP("1:2:3::4"),
		Port: 12345,
		Zone: "eth0",
	}

	// The expected value goes first, so that a failure labels the two values
	// correctly.
	assert.Equal(t, addr.IP.String(), IPStringFromAddr(&addr))
	assert.Empty(t, IPStringFromAddr(nil))
}
func TestMatchDNSName(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
})
}
}
// testDHCP is a stub DHCP server for tests.  It is passed to NewServer as the
// DHCPServer parameter in the PTR test.
type testDHCP struct{}

// Leases always returns a single lease that binds 127.0.0.1 to the host name
// "localhost".  The flags argument is ignored.
func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
	return []dhcpd.Lease{{
		IP:       net.IP{127, 0, 0, 1},
		HWAddr:   net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
		Hostname: "localhost",
	}}
}

// SetOnLeaseChanged does nothing: the stub never reports lease changes.
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
// TestPTRResponseFromDHCPLeases checks that a PTR query for 127.0.0.1 is
// answered with the host name of the matching DHCP lease.
//
// The server is closed in a cleanup, so that it is released even when an
// assertion fails or panics half-way through the test.
func TestPTRResponseFromDHCPLeases(t *testing.T) {
	dhcp := &testDHCP{}

	s := NewServer(DNSCreateParams{
		DNSFilter:  dnsfilter.New(&dnsfilter.Config{}, nil),
		DHCPServer: dhcp,
	})
	t.Cleanup(func() { s.Close() })

	s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
	s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
	s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
	s.conf.FilteringConfig.ProtectionEnabled = true

	err := s.Prepare(nil)
	assert.Nil(t, err)
	assert.Nil(t, s.Start())

	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
	req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)

	resp, err := dns.Exchange(req, addr.String())
	if !assert.Nil(t, err) {
		// There is no response to inspect after a failed exchange.
		return
	}

	// Stop here on a wrong answer count so that the indexing below cannot
	// panic on an empty answer section.
	if !assert.Len(t, resp.Answer, 1) {
		return
	}

	assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
	assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)

	ptr, ok := resp.Answer[0].(*dns.PTR)
	if assert.True(t, ok) {
		assert.Equal(t, "localhost.", ptr.Ptr)
	}
}
// TestPTRResponseFromHosts checks that a PTR query for 127.0.0.1 is answered
// with the host name found in a hosts file loaded through AutoHosts.
//
// The temporary hosts file, the AutoHosts instance and the server are all
// released in cleanups, so that nothing leaks when an assertion fails or
// panics half-way through the test.
func TestPTRResponseFromHosts(t *testing.T) {
	c := dnsfilter.Config{
		AutoHosts: &util.AutoHosts{},
	}

	// Prepare test hosts file.
	hf, err := ioutil.TempFile("", "")
	if !assert.Nil(t, err) {
		// Without the file there is nothing to test.
		return
	}
	t.Cleanup(func() {
		assert.Nil(t, hf.Close())
		assert.Nil(t, os.Remove(hf.Name()))
	})

	// A failed write would otherwise only show up as a confusing PTR
	// mismatch further down.
	_, err = hf.WriteString(" 127.0.0.1 host # comment \n")
	assert.Nil(t, err)
	_, err = hf.WriteString(" ::1 localhost#comment \n")
	assert.Nil(t, err)

	// Init auto hosts.
	c.AutoHosts.Init(hf.Name())
	t.Cleanup(c.AutoHosts.Close)

	s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)})
	// Cleanups run in reverse order, so the server is closed before the
	// auto hosts and the file, as it was when Close was called inline.
	t.Cleanup(func() { s.Close() })

	s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
	s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
	s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
	s.conf.FilteringConfig.ProtectionEnabled = true
	assert.Nil(t, s.Prepare(nil))
	assert.Nil(t, s.Start())

	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
	req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)

	resp, err := dns.Exchange(req, addr.String())
	if !assert.Nil(t, err) {
		// There is no response to inspect after a failed exchange.
		return
	}

	// Stop here on a wrong answer count so that the indexing below cannot
	// panic on an empty answer section.
	if !assert.Len(t, resp.Answer, 1) {
		return
	}

	assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
	assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)

	ptr, ok := resp.Answer[0].(*dns.PTR)
	if assert.True(t, ok) {
		assert.Equal(t, "host.", ptr.Ptr)
	}
}