mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 04:55:33 +03:00
Pull request: all: fix client upstreams, imp code
Updates #3186. Squashed commit of the following: commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri May 28 12:54:07 2021 +0300 all: imp code, names commit 98f86c21ae23b665095075feb4a59dcfcc622bc7 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu May 27 21:11:37 2021 +0300 all: fix client upstreams, imp code
This commit is contained in:
parent
48b8579703
commit
3be783bd34
18 changed files with 249 additions and 270 deletions
|
@ -32,6 +32,8 @@ released by then.
|
|||
|
||||
### Fixed
|
||||
|
||||
- Custom upstreams selection for clients with client IDs in DNS-over-TLS and
|
||||
DNS-over-HTTP ([#3186]).
|
||||
- Incorrect client-based filtering applying logic ([#2875]).
|
||||
|
||||
### Removed
|
||||
|
@ -40,6 +42,7 @@ released by then.
|
|||
|
||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
|
||||
[#3186]: https://github.com/AdguardTeam/AdGuardHome/issues/3186
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -10,6 +10,19 @@ import (
|
|||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// IPFromAddr returns an IP address from addr. If addr is neither
|
||||
// a *net.TCPAddr nor a *net.UDPAddr, it returns nil.
|
||||
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
||||
switch addr := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return addr.IP
|
||||
case *net.UDPAddr:
|
||||
return addr.IP
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidHostOuterRune returns true if r is a valid initial or final rune for
|
||||
// a hostname label.
|
||||
func IsValidHostOuterRune(r rune) (ok bool) {
|
||||
|
|
|
@ -9,6 +9,14 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPFromAddr(t *testing.T) {
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
assert.Equal(t, net.IP(nil), IPFromAddr(nil))
|
||||
assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{}))
|
||||
assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip}))
|
||||
assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip}))
|
||||
}
|
||||
|
||||
func TestValidateHardwareAddress(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
|
|
@ -19,6 +19,19 @@ func CloneSlice(a []string) (b []string) {
|
|||
return CloneSliceOrEmpty(a)
|
||||
}
|
||||
|
||||
// Coalesce returns the first non-empty string. It is named after the function
|
||||
// COALESCE in SQL except that since strings in Go are non-nullable, it uses an
|
||||
// empty string as a NULL value. If strs is empty, it returns an empty string.
|
||||
func Coalesce(strs ...string) (res string) {
|
||||
for _, s := range strs {
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FilterOut returns a copy of strs with all strings for which f returned true
|
||||
// removed.
|
||||
func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) {
|
||||
|
|
|
@ -36,6 +36,14 @@ func TestCloneSlice_family(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestCoalesce(t *testing.T) {
|
||||
assert.Equal(t, "", Coalesce())
|
||||
assert.Equal(t, "a", Coalesce("a"))
|
||||
assert.Equal(t, "a", Coalesce("", "a"))
|
||||
assert.Equal(t, "a", Coalesce("a", ""))
|
||||
assert.Equal(t, "a", Coalesce("a", "b"))
|
||||
}
|
||||
|
||||
func TestFilterOut(t *testing.T) {
|
||||
strs := []string{
|
||||
"1.2.3.4",
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
@ -27,11 +29,10 @@ type FilteringConfig struct {
|
|||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
|
||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||
//
|
||||
// TODO(e.burkov): Replace argument type with net.IP.
|
||||
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
||||
// GetCustomUpstreamByClient is a callback that returns upstreams
|
||||
// configuration based on the client IP address or ClientID. It returns
|
||||
// nil if there are no custom upstreams for the client.
|
||||
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
|
||||
|
||||
// Protection configuration
|
||||
// --
|
||||
|
@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// isInSorted returns true if s is in the sorted slice strs.
|
||||
func isInSorted(strs []string, s string) (ok bool) {
|
||||
i := sort.SearchStrings(strs, s)
|
||||
if i == len(strs) || strs[i] != s {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isWildcard returns true if host is a wildcard hostname.
|
||||
func isWildcard(host string) (ok bool) {
|
||||
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
|
||||
}
|
||||
|
||||
// matchesDomainWildcard returns true if host matches the domain wildcard
|
||||
// pattern pat.
|
||||
func matchesDomainWildcard(host, pat string) (ok bool) {
|
||||
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
|
||||
}
|
||||
|
||||
// anyNameMatches returns true if sni, the client's SNI value, matches any of
|
||||
// the DNS names and patterns from certificate. dnsNames must be sorted.
|
||||
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if isInSorted(dnsNames, sni) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, dn := range dnsNames {
|
||||
if matchesDomainWildcard(sni, dn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Called by 'tls' package when Client Hello is received
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
|
||||
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
|
||||
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||
return nil, fmt.Errorf("invalid SNI")
|
||||
}
|
||||
|
|
53
internal/dnsforward/config_test.go
Normal file
53
internal/dnsforward/config_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAnyNameMatches(t *testing.T) {
|
||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||
sort.Strings(dnsNames)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
dnsName string
|
||||
want bool
|
||||
}{{
|
||||
name: "match",
|
||||
dnsName: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "b.a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "1.2.3.4",
|
||||
want: true,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "*.host2",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
|||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
return rc
|
||||
}
|
||||
|
||||
|
@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
|||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// ipStringFromAddr extracts an IP address string from net.Addr.
|
||||
func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := aghnet.IPFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
|
@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
|||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
clientIP := IPStringFromAddr(d.Addr)
|
||||
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", clientIP)
|
||||
// Use the clientID first, since it has a higher priority.
|
||||
id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
|
||||
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
|
||||
if err != nil {
|
||||
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
||||
} else if upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", id)
|
||||
d.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
}
|
||||
|
|
|
@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
|||
require.Empty(t, proxyCtx.Res.Answer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
assert.Empty(t, ipStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
|||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
|
||||
return &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
},
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
return &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
}, nil
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
|
@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
assert.Empty(t, IPStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchDNSName(t *testing.T) {
|
||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||
sort.Strings(dnsNames)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
dnsName string
|
||||
want bool
|
||||
}{{
|
||||
name: "match",
|
||||
dnsName: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "b.a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "1.2.3.4",
|
||||
want: true,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "*.host2",
|
||||
want: false,
|
||||
}}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testDHCP struct{}
|
||||
|
||||
func (d *testDHCP) Enabled() (ok bool) { return true }
|
||||
|
|
|
@ -4,15 +4,15 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := IPFromAddr(d.Addr)
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
|
@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
|
|||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
if s.conf.FilterHandler != nil {
|
||||
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
|
@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
|||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: IPFromAddr(pctx.Addr),
|
||||
ClientIP: aghnet.IPFromAddr(pctx.Addr),
|
||||
ClientID: ctx.clientID,
|
||||
}
|
||||
|
||||
|
@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
|
|||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
|
||||
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
|
||||
e.Client = ip.String()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
)
|
||||
|
||||
// IPFromAddr gets IP address from addr.
|
||||
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP
|
||||
case *net.TCPAddr:
|
||||
return addr.IP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPStringFromAddr extracts IP address from net.Addr.
|
||||
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
|
||||
func IPStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := IPFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find value in a sorted array
|
||||
func findSorted(ar []string, val string) int {
|
||||
i := sort.SearchStrings(ar, val)
|
||||
if i == len(ar) || ar[i] != val {
|
||||
return -1
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func isWildcard(host string) bool {
|
||||
return len(host) >= 2 &&
|
||||
host[0] == '*' && host[1] == '.'
|
||||
}
|
||||
|
||||
// Return TRUE if host name matches a wildcard pattern
|
||||
func matchDomainWildcard(host, wildcard string) bool {
|
||||
return isWildcard(wildcard) &&
|
||||
strings.HasSuffix(host, wildcard[1:])
|
||||
}
|
||||
|
||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||
func matchDNSName(dnsNames []string, sni string) bool {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if findSorted(dnsNames, sni) != -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, dn := range dnsNames {
|
||||
if matchDomainWildcard(sni, dn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// fakeAddr is a mock implementation of net.Addr interface to simplify testing.
|
||||
type fakeAddr struct {
|
||||
// Addr is embedded here simply to make fakeAddr a net.Addr without
|
||||
// actually implementing all methods.
|
||||
net.Addr
|
||||
}
|
||||
|
||||
func TestIPFromAddr(t *testing.T) {
|
||||
supIPv4 := net.IP{1, 2, 3, 4}
|
||||
supIPv6 := net.ParseIP("2a00:1450:400c:c06::93")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
want net.IP
|
||||
}{{
|
||||
name: "ipv4_tcp",
|
||||
addr: &net.TCPAddr{
|
||||
IP: supIPv4,
|
||||
},
|
||||
want: supIPv4,
|
||||
}, {
|
||||
name: "ipv6_tcp",
|
||||
addr: &net.TCPAddr{
|
||||
IP: supIPv6,
|
||||
},
|
||||
want: supIPv6,
|
||||
}, {
|
||||
name: "ipv4_udp",
|
||||
addr: &net.UDPAddr{
|
||||
IP: supIPv4,
|
||||
},
|
||||
want: supIPv4,
|
||||
}, {
|
||||
name: "ipv6_udp",
|
||||
addr: &net.UDPAddr{
|
||||
IP: supIPv6,
|
||||
},
|
||||
want: supIPv6,
|
||||
}, {
|
||||
name: "non-ip_addr",
|
||||
addr: &fakeAddr{},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, IPFromAddr(tc.addr))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
|||
return c, true
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
// findUpstreams returns upstreams configured for the client, identified either
|
||||
// by its IP address or its ClientID. upsConf is nil if the client isn't found
|
||||
// or if the client has no custom upstreams.
|
||||
func (clients *clientsContainer) findUpstreams(
|
||||
id string,
|
||||
) (upsConf *proxy.UpstreamConfig, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findLocked(ip)
|
||||
c, ok := clients.findLocked(id)
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
conf, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: dnsforward.DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &conf
|
||||
}
|
||||
if c.upstreamConfig != nil {
|
||||
return c.upstreamConfig, nil
|
||||
}
|
||||
|
||||
return c.upstreamConfig
|
||||
var conf proxy.UpstreamConfig
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: dnsforward.DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.upstreamConfig = &conf
|
||||
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
// findLocked searches for a client by its ID. For internal use only.
|
||||
|
|
|
@ -25,7 +25,7 @@ func TestClients(t *testing.T) {
|
|||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
|
@ -34,7 +34,7 @@ func TestClients(t *testing.T) {
|
|||
}
|
||||
|
||||
ok, err = clients.Add(c)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
|
@ -59,7 +59,7 @@ func TestClients(t *testing.T) {
|
|||
IDs: []string{"1.2.3.5"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
|
@ -68,7 +68,7 @@ func TestClients(t *testing.T) {
|
|||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.NotNil(t, err)
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
|
@ -77,13 +77,13 @@ func TestClients(t *testing.T) {
|
|||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.NotNil(t, err)
|
||||
require.Error(t, err)
|
||||
|
||||
err = clients.Update("client3", &Client{
|
||||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client2",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
|
@ -91,7 +91,7 @@ func TestClients(t *testing.T) {
|
|||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_success", func(t *testing.T) {
|
||||
|
@ -99,7 +99,7 @@ func TestClients(t *testing.T) {
|
|||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
|
@ -109,7 +109,7 @@ func TestClients(t *testing.T) {
|
|||
Name: "client1-renamed",
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
require.True(t, ok)
|
||||
|
@ -137,15 +137,15 @@ func TestClients(t *testing.T) {
|
|||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
|
@ -153,7 +153,7 @@ func TestClients(t *testing.T) {
|
|||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) {
|
|||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||
|
@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) {
|
|||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||
|
@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
|
@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{testIP.String()},
|
||||
Name: "client2",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP
|
||||
|
@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) {
|
|||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
config, err := clients.findUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
config, err = clients.findUpstreams("1.1.1.1")
|
||||
require.NotNil(t, config)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, config.Upstreams, 1)
|
||||
assert.Len(t, config.DomainReservedUpstreams, 1)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
|
@ -106,7 +107,7 @@ func isRunning() bool {
|
|||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := dnsforward.IPFromAddr(d.Addr)
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
if ip == nil {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
|
@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
|||
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newConf.FilterHandler = applyAdditionalFiltering
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
|
||||
|
||||
newConf.ResolveClients = dnsConf.ResolveClients
|
||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -66,19 +67,6 @@ func trimValue(s string) string {
|
|||
return s[:maxValueLength-3] + "..."
|
||||
}
|
||||
|
||||
// coalesceStr returns the first non-empty string.
|
||||
//
|
||||
// TODO(a.garipov): Move to aghstrings?
|
||||
func coalesceStr(strs ...string) (res string) {
|
||||
for _, s := range strs {
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isWhoisComment returns true if the string is empty or is a WHOIS comment.
|
||||
func isWhoisComment(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
||||
|
@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) {
|
|||
v = trimValue(v)
|
||||
case "descr", "netname":
|
||||
k = "orgname"
|
||||
v = coalesceStr(orgname, v)
|
||||
v = aghstrings.Coalesce(orgname, v)
|
||||
orgname = v
|
||||
case "whois":
|
||||
k = "whois"
|
||||
|
|
Loading…
Reference in a new issue