mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-26 00:28:16 +03:00
300 lines
7.8 KiB
Go
300 lines
7.8 KiB
Go
|
package dnsforward
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||
|
"github.com/miekg/dns"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
blockedHost = "blockedhost.org"
|
||
|
testFQDN = "example.org."
|
||
|
dnsClientTimeout = 200 * time.Millisecond
|
||
|
)
|
||
|
|
||
|
func TestServer_HandleBefore_tls(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
const clientID = "client-1"
|
||
|
|
||
|
testCases := []struct {
|
||
|
clientSrvName string
|
||
|
name string
|
||
|
host string
|
||
|
allowedClients []string
|
||
|
disallowedClients []string
|
||
|
blockedHosts []string
|
||
|
wantRCode int
|
||
|
}{{
|
||
|
clientSrvName: tlsServerName,
|
||
|
name: "allow_all",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeSuccess,
|
||
|
}, {
|
||
|
clientSrvName: "%" + "." + tlsServerName,
|
||
|
name: "invalid_client_id",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeServerFailure,
|
||
|
}, {
|
||
|
clientSrvName: clientID + "." + tlsServerName,
|
||
|
name: "allowed_client_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{clientID},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeSuccess,
|
||
|
}, {
|
||
|
clientSrvName: "client-2." + tlsServerName,
|
||
|
name: "allowed_client_rejected",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{clientID},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeRefused,
|
||
|
}, {
|
||
|
clientSrvName: tlsServerName,
|
||
|
name: "disallowed_client_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{clientID},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeSuccess,
|
||
|
}, {
|
||
|
clientSrvName: clientID + "." + tlsServerName,
|
||
|
name: "disallowed_client_rejected",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{clientID},
|
||
|
blockedHosts: []string{},
|
||
|
wantRCode: dns.RcodeRefused,
|
||
|
}, {
|
||
|
clientSrvName: tlsServerName,
|
||
|
name: "blocked_hosts_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{blockedHost},
|
||
|
wantRCode: dns.RcodeSuccess,
|
||
|
}, {
|
||
|
clientSrvName: tlsServerName,
|
||
|
name: "blocked_hosts_rejected",
|
||
|
host: dns.Fqdn(blockedHost),
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{blockedHost},
|
||
|
wantRCode: dns.RcodeRefused,
|
||
|
}}
|
||
|
|
||
|
localAns := []dns.RR{&dns.A{
|
||
|
Hdr: dns.RR_Header{
|
||
|
Name: testFQDN,
|
||
|
Rrtype: dns.TypeA,
|
||
|
Class: dns.ClassINET,
|
||
|
Ttl: 3600,
|
||
|
Rdlength: 4,
|
||
|
},
|
||
|
A: net.IP{1, 2, 3, 4},
|
||
|
}}
|
||
|
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||
|
resp := (&dns.Msg{}).SetReply(req)
|
||
|
resp.Answer = localAns
|
||
|
|
||
|
require.NoError(t, w.WriteMsg(resp))
|
||
|
})
|
||
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
s, _ := createTestTLS(t, TLSConfig{
|
||
|
TLSListenAddrs: []*net.TCPAddr{{}},
|
||
|
ServerName: tlsServerName,
|
||
|
})
|
||
|
|
||
|
s.conf.UpstreamDNS = []string{localUpsAddr}
|
||
|
|
||
|
s.conf.AllowedClients = tc.allowedClients
|
||
|
s.conf.DisallowedClients = tc.disallowedClients
|
||
|
s.conf.BlockedHosts = tc.blockedHosts
|
||
|
|
||
|
err := s.Prepare(&s.conf)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
startDeferStop(t, s)
|
||
|
|
||
|
tlsConfig := &tls.Config{
|
||
|
InsecureSkipVerify: true,
|
||
|
ServerName: tc.clientSrvName,
|
||
|
}
|
||
|
|
||
|
client := &dns.Client{
|
||
|
Net: "tcp-tls",
|
||
|
TLSConfig: tlsConfig,
|
||
|
Timeout: dnsClientTimeout,
|
||
|
}
|
||
|
|
||
|
req := createTestMessage(tc.host)
|
||
|
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
|
||
|
|
||
|
reply, _, err := client.Exchange(req, addr)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
assert.Equal(t, tc.wantRCode, reply.Rcode)
|
||
|
if tc.wantRCode == dns.RcodeSuccess {
|
||
|
assert.Equal(t, localAns, reply.Answer)
|
||
|
} else {
|
||
|
assert.Empty(t, reply.Answer)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestServer_HandleBefore_udp(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
const (
|
||
|
clientIPv4 = "127.0.0.1"
|
||
|
clientIPv6 = "::1"
|
||
|
)
|
||
|
|
||
|
clientIPs := []string{clientIPv4, clientIPv6}
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
host string
|
||
|
allowedClients []string
|
||
|
disallowedClients []string
|
||
|
blockedHosts []string
|
||
|
wantTimeout bool
|
||
|
}{{
|
||
|
name: "allow_all",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantTimeout: false,
|
||
|
}, {
|
||
|
name: "allowed_client_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: clientIPs,
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantTimeout: false,
|
||
|
}, {
|
||
|
name: "allowed_client_rejected",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{"1:2:3::4"},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{},
|
||
|
wantTimeout: true,
|
||
|
}, {
|
||
|
name: "disallowed_client_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{"1:2:3::4"},
|
||
|
blockedHosts: []string{},
|
||
|
wantTimeout: false,
|
||
|
}, {
|
||
|
name: "disallowed_client_rejected",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: clientIPs,
|
||
|
blockedHosts: []string{},
|
||
|
wantTimeout: true,
|
||
|
}, {
|
||
|
name: "blocked_hosts_allowed",
|
||
|
host: testFQDN,
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{blockedHost},
|
||
|
wantTimeout: false,
|
||
|
}, {
|
||
|
name: "blocked_hosts_rejected",
|
||
|
host: dns.Fqdn(blockedHost),
|
||
|
allowedClients: []string{},
|
||
|
disallowedClients: []string{},
|
||
|
blockedHosts: []string{blockedHost},
|
||
|
wantTimeout: true,
|
||
|
}}
|
||
|
|
||
|
localAns := []dns.RR{&dns.A{
|
||
|
Hdr: dns.RR_Header{
|
||
|
Name: testFQDN,
|
||
|
Rrtype: dns.TypeA,
|
||
|
Class: dns.ClassINET,
|
||
|
Ttl: 3600,
|
||
|
Rdlength: 4,
|
||
|
},
|
||
|
A: net.IP{1, 2, 3, 4},
|
||
|
}}
|
||
|
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||
|
resp := (&dns.Msg{}).SetReply(req)
|
||
|
resp.Answer = localAns
|
||
|
|
||
|
require.NoError(t, w.WriteMsg(resp))
|
||
|
})
|
||
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
s := createTestServer(t, &filtering.Config{
|
||
|
BlockingMode: filtering.BlockingModeDefault,
|
||
|
}, ServerConfig{
|
||
|
UDPListenAddrs: []*net.UDPAddr{{}},
|
||
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||
|
Config: Config{
|
||
|
AllowedClients: tc.allowedClients,
|
||
|
DisallowedClients: tc.disallowedClients,
|
||
|
BlockedHosts: tc.blockedHosts,
|
||
|
UpstreamDNS: []string{localUpsAddr},
|
||
|
UpstreamMode: UpstreamModeLoadBalance,
|
||
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||
|
},
|
||
|
ServePlainDNS: true,
|
||
|
})
|
||
|
|
||
|
startDeferStop(t, s)
|
||
|
|
||
|
client := &dns.Client{
|
||
|
Net: "udp",
|
||
|
Timeout: dnsClientTimeout,
|
||
|
}
|
||
|
|
||
|
req := createTestMessage(tc.host)
|
||
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||
|
|
||
|
reply, _, err := client.Exchange(req, addr)
|
||
|
if tc.wantTimeout {
|
||
|
wantErr := &net.OpError{}
|
||
|
require.ErrorAs(t, err, &wantErr)
|
||
|
assert.True(t, wantErr.Timeout())
|
||
|
|
||
|
assert.Nil(t, reply)
|
||
|
} else {
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, reply)
|
||
|
|
||
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||
|
assert.Equal(t, localAns, reply.Answer)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|