-(dnsforward): custom client per-domain upstreams

Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539
This commit is contained in:
Andrey Meshkov 2020-05-13 20:31:43 +03:00
parent 1f954ab673
commit 67a39045fc
10 changed files with 106 additions and 88 deletions

View file

@ -70,3 +70,7 @@ issues:
- G108 - G108
# gosec: Subprocess launched with function call as argument or cmd arguments # gosec: Subprocess launched with function call as argument or cmd arguments
- G204 - G204
# gosec: Potential DoS vulnerability via decompression bomb
- G110
# gosec: Expect WriteFile permissions to be 0600 or less
- G306

View file

@ -26,8 +26,9 @@ type FilteringConfig struct {
// Filtering callback function // Filtering callback function
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// This callback function returns the list of upstream servers for a client specified by IP address // GetCustomUpstreamByClient - a callback function that returns upstreams configuration
GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"` // based on the client IP address. Returns nil if there are no custom upstreams for the client
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration // Protection configuration
// -- // --
@ -102,11 +103,10 @@ type TLSConfig struct {
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use. // The zero ServerConfig is empty and ready for use.
type ServerConfig struct { type ServerConfig struct {
UDPListenAddr *net.UDPAddr // UDP listen address UDPListenAddr *net.UDPAddr // UDP listen address
TCPListenAddr *net.TCPAddr // TCP listen address TCPListenAddr *net.TCPAddr // TCP listen address
Upstreams []upstream.Upstream // Configured upstreams UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams OnDNSRequest func(d *proxy.DNSContext)
OnDNSRequest func(d *proxy.DNSContext)
FilteringConfig FilteringConfig
TLSConfig TLSConfig
@ -132,22 +132,21 @@ var defaultValues = ServerConfig{
// createProxyConfig creates and validates configuration for the main proxy // createProxyConfig creates and validates configuration for the main proxy
func (s *Server) createProxyConfig() (proxy.Config, error) { func (s *Server) createProxyConfig() (proxy.Config, error) {
proxyConfig := proxy.Config{ proxyConfig := proxy.Config{
UDPListenAddr: s.conf.UDPListenAddr, UDPListenAddr: s.conf.UDPListenAddr,
TCPListenAddr: s.conf.TCPListenAddr, TCPListenAddr: s.conf.TCPListenAddr,
Ratelimit: int(s.conf.Ratelimit), Ratelimit: int(s.conf.Ratelimit),
RatelimitWhitelist: s.conf.RatelimitWhitelist, RatelimitWhitelist: s.conf.RatelimitWhitelist,
RefuseAny: s.conf.RefuseAny, RefuseAny: s.conf.RefuseAny,
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: int(s.conf.CacheSize), CacheSizeBytes: int(s.conf.CacheSize),
CacheMinTTL: s.conf.CacheMinTTL, CacheMinTTL: s.conf.CacheMinTTL,
CacheMaxTTL: s.conf.CacheMaxTTL, CacheMaxTTL: s.conf.CacheMaxTTL,
Upstreams: s.conf.Upstreams, UpstreamConfig: s.conf.UpstreamConfig,
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, BeforeRequestHandler: s.beforeRequestHandler,
BeforeRequestHandler: s.beforeRequestHandler, RequestHandler: s.handleDNSRequest,
RequestHandler: s.handleDNSRequest, AllServers: s.conf.AllServers,
AllServers: s.conf.AllServers, EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet,
EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, FindFastestAddr: s.conf.FastestAddr,
FindFastestAddr: s.conf.FastestAddr,
} }
if len(s.conf.BogusNXDomain) > 0 { if len(s.conf.BogusNXDomain) > 0 {
@ -168,7 +167,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
} }
// Validate proxy config // Validate proxy config
if len(proxyConfig.Upstreams) == 0 { if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
return proxyConfig, errors.New("no upstream servers configured") return proxyConfig, errors.New("no upstream servers configured")
} }
@ -204,18 +203,16 @@ func (s *Server) prepareUpstreamSettings() error {
if err != nil { if err != nil {
return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err)
} }
s.conf.Upstreams = upstreamConfig.Upstreams s.conf.UpstreamConfig = &upstreamConfig
s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
return nil return nil
} }
// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries // prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries
func (s *Server) prepareIntlProxy() { func (s *Server) prepareIntlProxy() {
intlProxyConfig := proxy.Config{ intlProxyConfig := proxy.Config{
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
Upstreams: s.conf.Upstreams, UpstreamConfig: s.conf.UpstreamConfig,
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
} }
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
} }

View file

@ -325,7 +325,9 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error {
if err != nil { if err != nil {
return err return err
} }
s.dnsProxy.Upstreams = []upstream.Upstream{u} s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{u},
}
return s.dnsProxy.Start() return s.dnsProxy.Start()
} }
@ -353,8 +355,8 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
// but protection is disabled - response is NOT blocked // but protection is disabled - response is NOT blocked
req := createTestMessage("badhost.") req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, reply.Rcode == dns.RcodeSuccess) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
} }
func TestBlockCNAME(t *testing.T) { func TestBlockCNAME(t *testing.T) {
@ -368,23 +370,23 @@ func TestBlockCNAME(t *testing.T) {
// response is blocked // response is blocked
req := createTestMessage("badhost.") req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err, nil)
assert.True(t, reply.Rcode == dns.RcodeNameError) assert.Equal(t, dns.RcodeNameError, reply.Rcode)
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist: // but 'whitelist.example.org' is in a whitelist:
// response isn't blocked // response isn't blocked
req = createTestMessage("whitelist.example.org.") req = createTestMessage("whitelist.example.org.")
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, reply.Rcode == dns.RcodeSuccess) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters: // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
// response is blocked // response is blocked
req = createTestMessage("example.org.") req = createTestMessage("example.org.")
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, reply.Rcode == dns.RcodeNameError) assert.Equal(t, dns.RcodeNameError, reply.Rcode)
_ = s.Stop() _ = s.Stop()
} }
@ -455,7 +457,7 @@ func TestNullBlockedRequest(t *testing.T) {
func TestBlockedCustomIP(t *testing.T) { func TestBlockedCustomIP(t *testing.T) {
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
filters := []dnsfilter.Filter{dnsfilter.Filter{ filters := []dnsfilter.Filter{{
ID: 0, Data: []byte(rules), ID: 0, Data: []byte(rules),
}} }}
c := dnsfilter.Config{} c := dnsfilter.Config{}
@ -475,27 +477,27 @@ func TestBlockedCustomIP(t *testing.T) {
conf.BlockingIPv4 = "0.0.0.1" conf.BlockingIPv4 = "0.0.0.1"
conf.BlockingIPv6 = "::1" conf.BlockingIPv6 = "::1"
err = s.Prepare(&conf) err = s.Prepare(&conf)
assert.True(t, err == nil) assert.Nil(t, err)
err = s.Start() err = s.Start()
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("null.example.org.", dns.TypeA) req := createTestMessageWithType("null.example.org.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, len(reply.Answer) == 1) assert.Equal(t, 1, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, a.A.String() == "0.0.0.1") assert.Equal(t, "0.0.0.1", a.A.String())
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, len(reply.Answer) == 1) assert.Equal(t, 1, len(reply.Answer))
a6, ok := reply.Answer[0].(*dns.AAAA) a6, ok := reply.Answer[0].(*dns.AAAA)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, a6.AAAA.String() == "::1") assert.Equal(t, "::1", a6.AAAA.String())
err = s.Stop() err = s.Stop()
if err != nil { if err != nil {
@ -598,7 +600,7 @@ func createTestServer(t *testing.T) *Server {
127.0.0.1 host.example.org 127.0.0.1 host.example.org
@@||whitelist.example.org^ @@||whitelist.example.org^
||127.0.0.255` ||127.0.0.255`
filters := []dnsfilter.Filter{dnsfilter.Filter{ filters := []dnsfilter.Filter{{
ID: 0, Data: []byte(rules), ID: 0, Data: []byte(rules),
}} }}
c := dnsfilter.Config{} c := dnsfilter.Config{}

View file

@ -10,7 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := ipFromAddr(d.Addr) ip := ipFromAddr(d.Addr)
if s.access.IsBlockedIP(ip) { if s.access.IsBlockedIP(ip) {
log.Tracef("Client IP %s is blocked by settings", ip) log.Tracef("Client IP %s is blocked by settings", ip)

View file

@ -31,7 +31,7 @@ const (
) )
// handleDNSRequest filters the incoming DNS requests and writes them to the query log // handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d} ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{} ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now() ctx.startTime = time.Now()
@ -124,12 +124,12 @@ func processUpstream(ctx *dnsContext) int {
return resultDone // response is already set - nothing to do return resultDone // response is already set - nothing to do
} }
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := ipFromAddr(d.Addr) clientIP := ipFromAddr(d.Addr)
upstreams := s.conf.GetUpstreamsByClient(clientIP) upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
if len(upstreams) > 0 { if upstreamsConf != nil {
log.Debug("Using custom upstreams for %s", clientIP) log.Debug("Using custom upstreams for %s", clientIP)
d.Upstreams = upstreams d.CustomUpstreamConfig = upstreamsConf
} }
} }

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.14 go 1.14
require ( require (
github.com/AdguardTeam/dnsproxy v0.28.0 github.com/AdguardTeam/dnsproxy v0.28.1
github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/golibs v0.4.2
github.com/AdguardTeam/urlfilter v0.10.0 github.com/AdguardTeam/urlfilter v0.10.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View file

@ -1,5 +1,5 @@
github.com/AdguardTeam/dnsproxy v0.28.0 h1:w6ITGjSMLztUOTVNVVcE0JU1bV2U0bOPyDHGwyZgTc4= github.com/AdguardTeam/dnsproxy v0.28.1 h1:WkLjrUcVf/njbTLyL7bNt6e18zQjF2ZYv/HWwL9cMmU=
github.com/AdguardTeam/dnsproxy v0.28.0/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= github.com/AdguardTeam/dnsproxy v0.28.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58=
github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU= github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=

View file

@ -11,11 +11,12 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils" "github.com/AdguardTeam/golibs/utils"
) )
@ -41,11 +42,12 @@ type Client struct {
BlockedServices []string BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests Upstreams []string // list of upstream servers to be used for the client's requests
// Upstream objects:
// Custom upstream config for this client
// nil: not yet initialized // nil: not yet initialized
// not nil, but empty: initialized, no good upstreams // not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used // not nil, not empty: Upstreams ready to be used
upstreamObjects []upstream.Upstream upstreamConfig *proxy.UpstreamConfig
} }
type clientSource uint type clientSource uint
@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return c, true return c, true
} }
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
a2 := make([]upstream.Upstream, len(a))
copy(a2, a)
return a2
}
// FindUpstreams looks for upstreams configured for the client // FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured, // If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil // this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
return nil return nil
} }
if c.upstreamObjects == nil { if len(c.Upstreams) == 0 {
c.upstreamObjects = make([]upstream.Upstream, 0) return nil
for _, us := range c.Upstreams { }
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil { if c.upstreamConfig == nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err) config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
continue if err == nil {
} c.upstreamConfig = &config
c.upstreamObjects = append(c.upstreamObjects, u)
} }
} }
if len(c.upstreamObjects) == 0 { return c.upstreamConfig
return nil
}
return upstreamArrayCopy(c.upstreamObjects)
} }
// Find searches for a client by IP (and does not lock anything) // Find searches for a client by IP (and does not lock anything)
@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
} }
// update upstreams cache // update upstreams cache
c.upstreamObjects = nil c.upstreamConfig = nil
*old = c *old = c
return nil return nil

View file

@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
assert.Nil(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
}

View file

@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig {
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newconfig.FilterHandler = applyAdditionalFiltering newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetUpstreamsByClient = getUpstreamsByClient newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
return newconfig return newconfig
} }
@ -222,10 +221,6 @@ func getDNSAddresses() []string {
return dnsAddresses return dnsAddresses
} }
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
return Context.clients.FindUpstreams(clientAddr)
}
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true) Context.dnsFilter.ApplyBlockedServices(setts, nil, true)