mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-23 05:25:35 +03:00
Merge: -(dnsforward): custom client per-domain upstreams
* commit '5d7b3fb7d5aa14c434dc532aef2fd68e54e7e182': -(dnsforward): fix handling RRSIG records Added a unit-test for custom upstreams -(dnsforward): custom client per-domain upstreams
This commit is contained in:
commit
39420c8a00
12 changed files with 179 additions and 129 deletions
|
@ -70,3 +70,7 @@ issues:
|
||||||
- G108
|
- G108
|
||||||
# gosec: Subprocess launched with function call as argument or cmd arguments
|
# gosec: Subprocess launched with function call as argument or cmd arguments
|
||||||
- G204
|
- G204
|
||||||
|
# gosec: Potential DoS vulnerability via decompression bomb
|
||||||
|
- G110
|
||||||
|
# gosec: Expect WriteFile permissions to be 0600 or less
|
||||||
|
- G306
|
||||||
|
|
|
@ -26,8 +26,9 @@ type FilteringConfig struct {
|
||||||
// Filtering callback function
|
// Filtering callback function
|
||||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||||
|
|
||||||
// This callback function returns the list of upstream servers for a client specified by IP address
|
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||||
GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
|
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||||
|
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
||||||
|
|
||||||
// Protection configuration
|
// Protection configuration
|
||||||
// --
|
// --
|
||||||
|
@ -104,8 +105,7 @@ type TLSConfig struct {
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
UDPListenAddr *net.UDPAddr // UDP listen address
|
UDPListenAddr *net.UDPAddr // UDP listen address
|
||||||
TCPListenAddr *net.TCPAddr // TCP listen address
|
TCPListenAddr *net.TCPAddr // TCP listen address
|
||||||
Upstreams []upstream.Upstream // Configured upstreams
|
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||||
DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
|
|
||||||
OnDNSRequest func(d *proxy.DNSContext)
|
OnDNSRequest func(d *proxy.DNSContext)
|
||||||
|
|
||||||
FilteringConfig
|
FilteringConfig
|
||||||
|
@ -141,8 +141,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||||
CacheSizeBytes: int(s.conf.CacheSize),
|
CacheSizeBytes: int(s.conf.CacheSize),
|
||||||
CacheMinTTL: s.conf.CacheMinTTL,
|
CacheMinTTL: s.conf.CacheMinTTL,
|
||||||
CacheMaxTTL: s.conf.CacheMaxTTL,
|
CacheMaxTTL: s.conf.CacheMaxTTL,
|
||||||
Upstreams: s.conf.Upstreams,
|
UpstreamConfig: s.conf.UpstreamConfig,
|
||||||
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
|
||||||
BeforeRequestHandler: s.beforeRequestHandler,
|
BeforeRequestHandler: s.beforeRequestHandler,
|
||||||
RequestHandler: s.handleDNSRequest,
|
RequestHandler: s.handleDNSRequest,
|
||||||
AllServers: s.conf.AllServers,
|
AllServers: s.conf.AllServers,
|
||||||
|
@ -168,7 +167,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate proxy config
|
// Validate proxy config
|
||||||
if len(proxyConfig.Upstreams) == 0 {
|
if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
|
||||||
return proxyConfig, errors.New("no upstream servers configured")
|
return proxyConfig, errors.New("no upstream servers configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,8 +203,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err)
|
return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err)
|
||||||
}
|
}
|
||||||
s.conf.Upstreams = upstreamConfig.Upstreams
|
s.conf.UpstreamConfig = &upstreamConfig
|
||||||
s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,8 +212,7 @@ func (s *Server) prepareIntlProxy() {
|
||||||
intlProxyConfig := proxy.Config{
|
intlProxyConfig := proxy.Config{
|
||||||
CacheEnabled: true,
|
CacheEnabled: true,
|
||||||
CacheSizeBytes: 4096,
|
CacheSizeBytes: 4096,
|
||||||
Upstreams: s.conf.Upstreams,
|
UpstreamConfig: s.conf.UpstreamConfig,
|
||||||
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
|
||||||
}
|
}
|
||||||
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
|
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,6 +249,39 @@ func TestBlockedRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
err := s.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
|
||||||
|
uc := &proxy.UpstreamConfig{}
|
||||||
|
u := &testUpstream{}
|
||||||
|
u.ipv4 = map[string][]net.IP{}
|
||||||
|
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
|
||||||
|
uc.Upstreams = append(uc.Upstreams, u)
|
||||||
|
return uc
|
||||||
|
}
|
||||||
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
// Send test request
|
||||||
|
req := dns.Msg{}
|
||||||
|
req.Id = dns.Id()
|
||||||
|
req.RecursionDesired = true
|
||||||
|
req.Question = []dns.Question{
|
||||||
|
{Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := dns.Exchange(&req, addr.String())
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
|
assert.NotNil(t, reply.Answer)
|
||||||
|
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
|
||||||
|
assert.Nil(t, s.Stop())
|
||||||
|
}
|
||||||
|
|
||||||
// testUpstream is a mock of real upstream.
|
// testUpstream is a mock of real upstream.
|
||||||
// specify fields with necessary values to simulate real upstream behaviour
|
// specify fields with necessary values to simulate real upstream behaviour
|
||||||
type testUpstream struct {
|
type testUpstream struct {
|
||||||
|
@ -325,7 +358,9 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.dnsProxy.Upstreams = []upstream.Upstream{u}
|
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||||
|
Upstreams: []upstream.Upstream{u},
|
||||||
|
}
|
||||||
return s.dnsProxy.Start()
|
return s.dnsProxy.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -353,8 +388,8 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||||
// but protection is disabled - response is NOT blocked
|
// but protection is disabled - response is NOT blocked
|
||||||
req := createTestMessage("badhost.")
|
req := createTestMessage("badhost.")
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBlockCNAME(t *testing.T) {
|
func TestBlockCNAME(t *testing.T) {
|
||||||
|
@ -368,23 +403,23 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
// response is blocked
|
// response is blocked
|
||||||
req := createTestMessage("badhost.")
|
req := createTestMessage("badhost.")
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err, nil)
|
||||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
assert.Equal(t, dns.RcodeNameError, reply.Rcode)
|
||||||
|
|
||||||
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
|
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
|
||||||
// but 'whitelist.example.org' is in a whitelist:
|
// but 'whitelist.example.org' is in a whitelist:
|
||||||
// response isn't blocked
|
// response isn't blocked
|
||||||
req = createTestMessage("whitelist.example.org.")
|
req = createTestMessage("whitelist.example.org.")
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
|
|
||||||
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
|
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
|
||||||
// response is blocked
|
// response is blocked
|
||||||
req = createTestMessage("example.org.")
|
req = createTestMessage("example.org.")
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
assert.Equal(t, dns.RcodeNameError, reply.Rcode)
|
||||||
|
|
||||||
_ = s.Stop()
|
_ = s.Stop()
|
||||||
}
|
}
|
||||||
|
@ -455,7 +490,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||||
|
|
||||||
func TestBlockedCustomIP(t *testing.T) {
|
func TestBlockedCustomIP(t *testing.T) {
|
||||||
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
|
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
|
||||||
filters := []dnsfilter.Filter{dnsfilter.Filter{
|
filters := []dnsfilter.Filter{{
|
||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
c := dnsfilter.Config{}
|
c := dnsfilter.Config{}
|
||||||
|
@ -475,27 +510,27 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
conf.BlockingIPv4 = "0.0.0.1"
|
conf.BlockingIPv4 = "0.0.0.1"
|
||||||
conf.BlockingIPv6 = "::1"
|
conf.BlockingIPv6 = "::1"
|
||||||
err = s.Prepare(&conf)
|
err = s.Prepare(&conf)
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
assert.True(t, err == nil, "%s", err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
||||||
reply, err := dns.Exchange(req, addr.String())
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
assert.True(t, len(reply.Answer) == 1)
|
assert.Equal(t, 1, len(reply.Answer))
|
||||||
a, ok := reply.Answer[0].(*dns.A)
|
a, ok := reply.Answer[0].(*dns.A)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.True(t, a.A.String() == "0.0.0.1")
|
assert.Equal(t, "0.0.0.1", a.A.String())
|
||||||
|
|
||||||
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
||||||
reply, err = dns.Exchange(req, addr.String())
|
reply, err = dns.Exchange(req, addr.String())
|
||||||
assert.True(t, err == nil)
|
assert.Nil(t, err)
|
||||||
assert.True(t, len(reply.Answer) == 1)
|
assert.Equal(t, 1, len(reply.Answer))
|
||||||
a6, ok := reply.Answer[0].(*dns.AAAA)
|
a6, ok := reply.Answer[0].(*dns.AAAA)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.True(t, a6.AAAA.String() == "::1")
|
assert.Equal(t, "::1", a6.AAAA.String())
|
||||||
|
|
||||||
err = s.Stop()
|
err = s.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -598,7 +633,7 @@ func createTestServer(t *testing.T) *Server {
|
||||||
127.0.0.1 host.example.org
|
127.0.0.1 host.example.org
|
||||||
@@||whitelist.example.org^
|
@@||whitelist.example.org^
|
||||||
||127.0.0.255`
|
||127.0.0.255`
|
||||||
filters := []dnsfilter.Filter{dnsfilter.Filter{
|
filters := []dnsfilter.Filter{{
|
||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
c := dnsfilter.Config{}
|
c := dnsfilter.Config{}
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||||
ip := ipFromAddr(d.Addr)
|
ip := ipFromAddr(d.Addr)
|
||||||
if s.access.IsBlockedIP(ip) {
|
if s.access.IsBlockedIP(ip) {
|
||||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||||
|
|
|
@ -31,7 +31,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
ctx := &dnsContext{srv: s, proxyCtx: d}
|
ctx := &dnsContext{srv: s, proxyCtx: d}
|
||||||
ctx.result = &dnsfilter.Result{}
|
ctx.result = &dnsfilter.Result{}
|
||||||
ctx.startTime = time.Now()
|
ctx.startTime = time.Now()
|
||||||
|
@ -124,12 +124,12 @@ func processUpstream(ctx *dnsContext) int {
|
||||||
return resultDone // response is already set - nothing to do
|
return resultDone // response is already set - nothing to do
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||||
clientIP := ipFromAddr(d.Addr)
|
clientIP := ipFromAddr(d.Addr)
|
||||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
|
||||||
if len(upstreams) > 0 {
|
if upstreamsConf != nil {
|
||||||
log.Debug("Using custom upstreams for %s", clientIP)
|
log.Debug("Using custom upstreams for %s", clientIP)
|
||||||
d.Upstreams = upstreams
|
d.CustomUpstreamConfig = upstreamsConf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,8 +165,9 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
|
||||||
return resultDone
|
return resultDone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !ctx.origReqDNSSEC {
|
||||||
optResp := d.Res.IsEdns0()
|
optResp := d.Res.IsEdns0()
|
||||||
if !ctx.origReqDNSSEC && optResp != nil && optResp.Do() {
|
if optResp != nil && !optResp.Do() {
|
||||||
return resultDone
|
return resultDone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,6 +197,7 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
d.Res.Ns = answers
|
d.Res.Ns = answers
|
||||||
|
}
|
||||||
|
|
||||||
return resultDone
|
return resultDone
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
package dnsforward
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
// GetIPString is a helper function that extracts IP address from net.Addr
|
|
||||||
func GetIPString(addr net.Addr) string {
|
|
||||||
switch addr := addr.(type) {
|
|
||||||
case *net.UDPAddr:
|
|
||||||
return addr.IP.String()
|
|
||||||
case *net.TCPAddr:
|
|
||||||
return addr.IP.String()
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
|
@ -8,6 +8,17 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
"github.com/AdguardTeam/golibs/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetIPString is a helper function that extracts IP address from net.Addr
|
||||||
|
func GetIPString(addr net.Addr) string {
|
||||||
|
switch addr := addr.(type) {
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return addr.IP.String()
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return addr.IP.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func stringArrayDup(a []string) []string {
|
func stringArrayDup(a []string) []string {
|
||||||
a2 := make([]string, len(a))
|
a2 := make([]string, len(a))
|
||||||
copy(a2, a)
|
copy(a2, a)
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.28.0
|
github.com/AdguardTeam/dnsproxy v0.28.1
|
||||||
github.com/AdguardTeam/golibs v0.4.2
|
github.com/AdguardTeam/golibs v0.4.2
|
||||||
github.com/AdguardTeam/urlfilter v0.10.0
|
github.com/AdguardTeam/urlfilter v0.10.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -1,5 +1,5 @@
|
||||||
github.com/AdguardTeam/dnsproxy v0.28.0 h1:w6ITGjSMLztUOTVNVVcE0JU1bV2U0bOPyDHGwyZgTc4=
|
github.com/AdguardTeam/dnsproxy v0.28.1 h1:WkLjrUcVf/njbTLyL7bNt6e18zQjF2ZYv/HWwL9cMmU=
|
||||||
github.com/AdguardTeam/dnsproxy v0.28.0/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58=
|
github.com/AdguardTeam/dnsproxy v0.28.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58=
|
||||||
github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU=
|
github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU=
|
||||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
|
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
|
||||||
|
|
|
@ -11,11 +11,12 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/utils"
|
"github.com/AdguardTeam/golibs/utils"
|
||||||
)
|
)
|
||||||
|
@ -41,11 +42,12 @@ type Client struct {
|
||||||
BlockedServices []string
|
BlockedServices []string
|
||||||
|
|
||||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||||
// Upstream objects:
|
|
||||||
|
// Custom upstream config for this client
|
||||||
// nil: not yet initialized
|
// nil: not yet initialized
|
||||||
// not nil, but empty: initialized, no good upstreams
|
// not nil, but empty: initialized, no good upstreams
|
||||||
// not nil, not empty: Upstreams ready to be used
|
// not nil, not empty: Upstreams ready to be used
|
||||||
upstreamObjects []upstream.Upstream
|
upstreamConfig *proxy.UpstreamConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientSource uint
|
type clientSource uint
|
||||||
|
@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
|
|
||||||
a2 := make([]upstream.Upstream, len(a))
|
|
||||||
copy(a2, a)
|
|
||||||
return a2
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindUpstreams looks for upstreams configured for the client
|
// FindUpstreams looks for upstreams configured for the client
|
||||||
// If no client found for this IP, or if no custom upstreams are configured,
|
// If no client found for this IP, or if no custom upstreams are configured,
|
||||||
// this method returns nil
|
// this method returns nil
|
||||||
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.upstreamObjects == nil {
|
if len(c.Upstreams) == 0 {
|
||||||
c.upstreamObjects = make([]upstream.Upstream, 0)
|
return nil
|
||||||
for _, us := range c.Upstreams {
|
|
||||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
|
||||||
if err != nil {
|
|
||||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
c.upstreamObjects = append(c.upstreamObjects, u)
|
|
||||||
|
if c.upstreamConfig == nil {
|
||||||
|
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||||
|
if err == nil {
|
||||||
|
c.upstreamConfig = &config
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.upstreamObjects) == 0 {
|
return c.upstreamConfig
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return upstreamArrayCopy(c.upstreamObjects)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find searches for a client by IP (and does not lock anything)
|
// Find searches for a client by IP (and does not lock anything)
|
||||||
|
@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// update upstreams cache
|
// update upstreams cache
|
||||||
c.upstreamObjects = nil
|
c.upstreamConfig = nil
|
||||||
|
|
||||||
*old = c
|
*old = c
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientsCustomUpstream(t *testing.T) {
|
||||||
|
clients := clientsContainer{}
|
||||||
|
clients.testing = true
|
||||||
|
|
||||||
|
clients.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
// add client with upstreams
|
||||||
|
client := Client{
|
||||||
|
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||||
|
Name: "client1",
|
||||||
|
Upstreams: []string{
|
||||||
|
"1.1.1.1",
|
||||||
|
"[/example.org/]8.8.8.8",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ok, err := clients.Add(client)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
config := clients.FindUpstreams("1.2.3.4")
|
||||||
|
assert.Nil(t, config)
|
||||||
|
|
||||||
|
config = clients.FindUpstreams("1.1.1.1")
|
||||||
|
assert.NotNil(t, config)
|
||||||
|
assert.Equal(t, 1, len(config.Upstreams))
|
||||||
|
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||||
|
|
||||||
newconfig.FilterHandler = applyAdditionalFiltering
|
newconfig.FilterHandler = applyAdditionalFiltering
|
||||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||||
return newconfig
|
return newconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,10 +221,6 @@ func getDNSAddresses() []string {
|
||||||
return dnsAddresses
|
return dnsAddresses
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
|
||||||
return Context.clients.FindUpstreams(clientAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a client has his own settings, apply them
|
// If a client has his own settings, apply them
|
||||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||||
|
|
Loading…
Reference in a new issue