mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 04:55:33 +03:00
Pull request 2080: AG-27539 imp upstream test
Squashed commit of the following:
commit 5a9e8c0c2e4b68c0ff6508c47fbd8abde0d05e95
Merge: 85820c173 c4e69cd96
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Nov 28 16:09:53 2023 +0300
Merge branch 'master' into AG-27539-imp-upstream-test
commit 85820c173dddb6391dabe9615b821b585b1ecdef
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Nov 28 15:48:53 2023 +0300
dnsforward: split code
commit dac0148a4d4780bea19fb7622b46ac08fbf1ee74
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Nov 24 15:47:05 2023 +0300
dnsforward: fix docs
commit 9f0015b255d547f31d34513aa6bb2faf65a39e0e
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Nov 24 14:45:43 2023 +0300
dnsforward: imp code
commit 49fefc373972b7c8991abcb46d7730288b92c24c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Nov 23 14:12:02 2023 +0300
dnsforward: imp code
commit 120ba4b1f727bba537471c4a8aa4b412eac30f85
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Nov 22 17:02:01 2023 +0300
dnsforward: add tests
commit 70775975ced46191a6ba64504c7bac0e3d1eed7f
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Nov 22 15:48:05 2023 +0300
dnsforward: imp code
commit 9487f1fd62b821efb242267d9972f3ae3785ad19
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Nov 21 18:06:00 2023 +0300
dnsforward: imp code
commit e2612e0e6fd1c9116872939edd0e86f2e9af07d7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Nov 21 16:12:20 2023 +0300
dnsforward: add ups checker
commit 09db7d2a604809669affbeef2f0536fa6605a39b
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Nov 14 17:31:04 2023 +0300
dnsforward: separate upstream code
This commit is contained in:
parent
c4e69cd961
commit
849abaf25e
6 changed files with 785 additions and 506 deletions
349
internal/dnsforward/configvalidator.go
Normal file
349
internal/dnsforward/configvalidator.go
Normal file
|
@ -0,0 +1,349 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
|
||||||
|
// actual DNS availability of each upstream.
|
||||||
|
type upstreamConfigValidator struct {
|
||||||
|
// general is the general upstream configuration.
|
||||||
|
general []*upstreamResult
|
||||||
|
|
||||||
|
// fallback is the fallback upstream configuration.
|
||||||
|
fallback []*upstreamResult
|
||||||
|
|
||||||
|
// private is the private upstream configuration.
|
||||||
|
private []*upstreamResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamResult is a result of validation of an [upstream.Upstream] within an
|
||||||
|
// [proxy.UpstreamConfig].
|
||||||
|
type upstreamResult struct {
|
||||||
|
// server is the parsed upstream. It is nil when there was an error during
|
||||||
|
// parsing.
|
||||||
|
server upstream.Upstream
|
||||||
|
|
||||||
|
// err is the error either from parsing or from checking the upstream.
|
||||||
|
err error
|
||||||
|
|
||||||
|
// original is the piece of configuration that have either been turned to an
|
||||||
|
// upstream or caused an error.
|
||||||
|
original string
|
||||||
|
|
||||||
|
// isSpecific is true if the upstream is domain-specific.
|
||||||
|
isSpecific bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
|
||||||
|
// if ur should be sorted before other, and 1 otherwise.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
|
||||||
|
// the end.
|
||||||
|
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
|
||||||
|
return strings.Compare(ur.original, other.original)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
||||||
|
// validator for it. cv already contains the parsed upstreams along with errors
|
||||||
|
// related.
|
||||||
|
func newUpstreamConfigValidator(
|
||||||
|
general []string,
|
||||||
|
fallback []string,
|
||||||
|
private []string,
|
||||||
|
opts *upstream.Options,
|
||||||
|
) (cv *upstreamConfigValidator) {
|
||||||
|
cv = &upstreamConfigValidator{}
|
||||||
|
|
||||||
|
for _, line := range general {
|
||||||
|
cv.general = cv.insertLineResults(cv.general, line, opts)
|
||||||
|
}
|
||||||
|
for _, line := range fallback {
|
||||||
|
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
|
||||||
|
}
|
||||||
|
for _, line := range private {
|
||||||
|
cv.private = cv.insertLineResults(cv.private, line, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cv
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertLineResults parses line and inserts the result into s. It can insert
|
||||||
|
// multiple results as well as none.
|
||||||
|
func (cv *upstreamConfigValidator) insertLineResults(
|
||||||
|
s []*upstreamResult,
|
||||||
|
line string,
|
||||||
|
opts *upstream.Options,
|
||||||
|
) (result []*upstreamResult) {
|
||||||
|
upstreams, isSpecific, err := splitUpstreamLine(line)
|
||||||
|
if err != nil {
|
||||||
|
return cv.insert(s, &upstreamResult{
|
||||||
|
err: err,
|
||||||
|
original: line,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, upstreamAddr := range upstreams {
|
||||||
|
var res *upstreamResult
|
||||||
|
if upstreamAddr != "#" {
|
||||||
|
res = cv.parseUpstream(upstreamAddr, opts)
|
||||||
|
} else if !isSpecific {
|
||||||
|
res = &upstreamResult{
|
||||||
|
err: errNotDomainSpecific,
|
||||||
|
original: upstreamAddr,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
res.isSpecific = isSpecific
|
||||||
|
s = cv.insert(s, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert inserts r into slice in a sorted order, except duplicates. slice must
|
||||||
|
// not be nil.
|
||||||
|
func (cv *upstreamConfigValidator) insert(
|
||||||
|
s []*upstreamResult,
|
||||||
|
r *upstreamResult,
|
||||||
|
) (result []*upstreamResult) {
|
||||||
|
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
|
||||||
|
if has {
|
||||||
|
log.Debug("dnsforward: duplicate configuration %q", r.original)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Insert(s, i, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseUpstream parses addr and returns the result of parsing. It returns nil
|
||||||
|
// if the specified server points at the default upstream server which is
|
||||||
|
// validated separately.
|
||||||
|
func (cv *upstreamConfigValidator) parseUpstream(
|
||||||
|
addr string,
|
||||||
|
opts *upstream.Options,
|
||||||
|
) (r *upstreamResult) {
|
||||||
|
// Check if the upstream has a valid protocol prefix.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Validate the domain name.
|
||||||
|
if proto, _, ok := strings.Cut(addr, "://"); ok {
|
||||||
|
if !slices.Contains(protocols, proto) {
|
||||||
|
return &upstreamResult{
|
||||||
|
err: fmt.Errorf("bad protocol %q", proto),
|
||||||
|
original: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ups, err := upstream.AddressToUpstream(addr, opts)
|
||||||
|
|
||||||
|
return &upstreamResult{
|
||||||
|
server: ups,
|
||||||
|
err: err,
|
||||||
|
original: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check tries to exchange with each successfully parsed upstream and enriches
|
||||||
|
// the results with the healthcheck errors. It should not be called after the
|
||||||
|
// [upsConfValidator.close] method, since it makes no sense to check the closed
|
||||||
|
// upstreams.
|
||||||
|
func (cv *upstreamConfigValidator) check() {
|
||||||
|
const (
|
||||||
|
// testTLD is the special-use fully-qualified domain name for testing
|
||||||
|
// the DNS server reachability.
|
||||||
|
//
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||||
|
testTLD = "test."
|
||||||
|
|
||||||
|
// inAddrARPATLD is the special-use fully-qualified domain name for PTR
|
||||||
|
// IP address resolution.
|
||||||
|
//
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||||
|
inAddrARPATLD = "in-addr.arpa."
|
||||||
|
)
|
||||||
|
|
||||||
|
commonChecker := &healthchecker{
|
||||||
|
hostname: testTLD,
|
||||||
|
qtype: dns.TypeA,
|
||||||
|
ansEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
arpaChecker := &healthchecker{
|
||||||
|
hostname: inAddrARPATLD,
|
||||||
|
qtype: dns.TypePTR,
|
||||||
|
ansEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
|
||||||
|
|
||||||
|
for _, res := range cv.general {
|
||||||
|
go cv.checkSrv(res, wg, commonChecker)
|
||||||
|
}
|
||||||
|
for _, res := range cv.fallback {
|
||||||
|
go cv.checkSrv(res, wg, commonChecker)
|
||||||
|
}
|
||||||
|
for _, res := range cv.private {
|
||||||
|
go cv.checkSrv(res, wg, arpaChecker)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSrv runs hc on the server from res, if any, and stores any occurred
|
||||||
|
// error in res. wg is always marked done in the end. It used to be called in
|
||||||
|
// a separate goroutine.
|
||||||
|
func (cv *upstreamConfigValidator) checkSrv(
|
||||||
|
res *upstreamResult,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
|
hc *healthchecker,
|
||||||
|
) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if res.server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res.err = hc.check(res.server)
|
||||||
|
if res.err != nil && res.isSpecific {
|
||||||
|
res.err = domainSpecificTestError{Err: res.err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// close closes all the upstreams that were successfully parsed. It enriches
|
||||||
|
// the results with deferred closing errors.
|
||||||
|
func (cv *upstreamConfigValidator) close() {
|
||||||
|
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
|
||||||
|
for _, r := range slice {
|
||||||
|
if r.server != nil {
|
||||||
|
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// status returns all the data collected during parsing, healthcheck, and
|
||||||
|
// closing of the upstreams. The returned map is keyed by the original upstream
|
||||||
|
// configuration piece and contains the corresponding error or "OK" if there was
|
||||||
|
// no error.
|
||||||
|
func (cv *upstreamConfigValidator) status() (results map[string]string) {
|
||||||
|
result := map[string]string{}
|
||||||
|
|
||||||
|
for _, res := range cv.general {
|
||||||
|
resultToStatus("general", res, result)
|
||||||
|
}
|
||||||
|
for _, res := range cv.fallback {
|
||||||
|
resultToStatus("fallback", res, result)
|
||||||
|
}
|
||||||
|
for _, res := range cv.private {
|
||||||
|
resultToStatus("private", res, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// resultToStatus puts "OK" or an error message from res into resMap. section
|
||||||
|
// is the name of the upstream configuration section, i.e. "general",
|
||||||
|
// "fallback", or "private", and only used for logging.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
|
||||||
|
// put together in a single map, which may lead to collisions, see AG-27539.
|
||||||
|
// Improve the results compilation.
|
||||||
|
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
|
||||||
|
val := "OK"
|
||||||
|
if res.err != nil {
|
||||||
|
val = res.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
prevVal := resMap[res.original]
|
||||||
|
switch prevVal {
|
||||||
|
case "":
|
||||||
|
resMap[res.original] = val
|
||||||
|
case val:
|
||||||
|
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
|
||||||
|
default:
|
||||||
|
log.Debug(
|
||||||
|
"dnsforward: warning: %s config line %q (%v) had different result %v",
|
||||||
|
section,
|
||||||
|
val,
|
||||||
|
res.original,
|
||||||
|
prevVal,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||||
|
// the tested upstream domain-specific and therefore consider its errors
|
||||||
|
// non-critical.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||||
|
// warnings (non-critical errors) is desired.
|
||||||
|
type domainSpecificTestError struct {
|
||||||
|
// Err is the actual error occurred during healthcheck test.
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ error = domainSpecificTestError{}
|
||||||
|
|
||||||
|
// Error implements the [error] interface for domainSpecificTestError.
|
||||||
|
func (err domainSpecificTestError) Error() (msg string) {
|
||||||
|
return fmt.Sprintf("WARNING: %s", err.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ errors.Wrapper = domainSpecificTestError{}
|
||||||
|
|
||||||
|
// Unwrap implements the [errors.Wrapper] interface for domainSpecificTestError.
|
||||||
|
func (err domainSpecificTestError) Unwrap() (wrapped error) {
|
||||||
|
return err.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthchecker checks the upstream's status by exchanging with it.
|
||||||
|
type healthchecker struct {
|
||||||
|
// hostname is the name of the host to put into healthcheck DNS request.
|
||||||
|
hostname string
|
||||||
|
|
||||||
|
// qtype is the type of DNS request to use for healthcheck.
|
||||||
|
qtype uint16
|
||||||
|
|
||||||
|
// ansEmpty defines if the answer section within the response is expected to
|
||||||
|
// be empty.
|
||||||
|
ansEmpty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// check exchanges with u and validates the response.
|
||||||
|
func (h *healthchecker) check(u upstream.Upstream) (err error) {
|
||||||
|
req := &dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Id: dns.Id(),
|
||||||
|
RecursionDesired: true,
|
||||||
|
},
|
||||||
|
Question: []dns.Question{{
|
||||||
|
Name: h.hostname,
|
||||||
|
Qtype: h.qtype,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := u.Exchange(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||||
|
} else if h.ansEmpty && len(reply.Answer) > 0 {
|
||||||
|
return errWrongResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -35,6 +35,11 @@ import (
|
||||||
// DefaultTimeout is the default upstream timeout
|
// DefaultTimeout is the default upstream timeout
|
||||||
const DefaultTimeout = 10 * time.Second
|
const DefaultTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// defaultLocalTimeout is the default timeout for resolving addresses from
|
||||||
|
// locally-served networks. It is assumed that local resolvers should work much
|
||||||
|
// faster than ordinary upstreams.
|
||||||
|
const defaultLocalTimeout = 1 * time.Second
|
||||||
|
|
||||||
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
|
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
|
||||||
// cache. The assumption here is that there won't be more than this many
|
// cache. The assumption here is that there won't be more than this many
|
||||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||||
|
@ -459,11 +464,6 @@ func (s *Server) startLocked() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultLocalTimeout is the default timeout for resolving addresses from
|
|
||||||
// locally-served networks. It is assumed that local resolvers should work much
|
|
||||||
// faster than ordinary upstreams.
|
|
||||||
const defaultLocalTimeout = 1 * time.Second
|
|
||||||
|
|
||||||
// setupLocalResolvers initializes the resolvers for local addresses. It
|
// setupLocalResolvers initializes the resolvers for local addresses. It
|
||||||
// assumes s.serverLock is locked or the Server not running.
|
// assumes s.serverLock is locked or the Server not running.
|
||||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
|
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
|
||||||
|
|
|
@ -4,23 +4,17 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/miekg/dns"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -546,365 +540,6 @@ type upstreamJSON struct {
|
||||||
PrivateUpstreams []string `json:"private_upstream"`
|
PrivateUpstreams []string `json:"private_upstream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
|
||||||
// This function is useful for filtering out non-upstream lines from upstream
|
|
||||||
// configs.
|
|
||||||
func IsCommentOrEmpty(s string) (ok bool) {
|
|
||||||
return len(s) == 0 || s[0] == '#'
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
|
||||||
// configuration or nil if it can't be built.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
|
||||||
// slice already so that this function may be considered useless.
|
|
||||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
|
||||||
// No need to validate comments and empty lines.
|
|
||||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
|
||||||
if len(upstreams) == 0 {
|
|
||||||
// Consider this case valid since it means the default server should be
|
|
||||||
// used.
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validateUpstreamConfig(upstreams)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conf, err = proxy.ParseUpstreamsConfig(
|
|
||||||
upstreams,
|
|
||||||
&upstream.Options{
|
|
||||||
Bootstrap: net.DefaultResolver,
|
|
||||||
Timeout: DefaultTimeout,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
} else if len(conf.Upstreams) == 0 {
|
|
||||||
return nil, errors.Error("no default upstreams specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
return conf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateUpstreamConfig validates each upstream from the upstream
|
|
||||||
// configuration and returns an error if any upstream is invalid.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
|
||||||
func validateUpstreamConfig(conf []string) (err error) {
|
|
||||||
for _, u := range conf {
|
|
||||||
var ups []string
|
|
||||||
var domains []string
|
|
||||||
ups, domains, err = separateUpstream(u)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, addr := range ups {
|
|
||||||
_, err = validateUpstream(addr, domains)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateUpstreams validates each upstream and returns an error if any
|
|
||||||
// upstream is invalid or if there are no default upstreams specified.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
|
||||||
func ValidateUpstreams(upstreams []string) (err error) {
|
|
||||||
_, err = newUpstreamConfig(upstreams)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
|
||||||
// upstream is invalid or if there are no default upstreams specified. It also
|
|
||||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
|
||||||
// a locally-served network. privateNets must not be nil.
|
|
||||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
|
||||||
conf, err := newUpstreamConfig(upstreams)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
|
||||||
slices.Sort(keys)
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
for _, domain := range keys {
|
|
||||||
var subnet netip.Prefix
|
|
||||||
subnet, err = extractARPASubnet(domain)
|
|
||||||
if err != nil {
|
|
||||||
errs = append(errs, err)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
|
||||||
errs = append(
|
|
||||||
errs,
|
|
||||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
|
||||||
}
|
|
||||||
|
|
||||||
var protocols = []string{
|
|
||||||
"h3://",
|
|
||||||
"https://",
|
|
||||||
"quic://",
|
|
||||||
"sdns://",
|
|
||||||
"tcp://",
|
|
||||||
"tls://",
|
|
||||||
"udp://",
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
|
||||||
// upstream configuration. useDefault is true if the upstream is
|
|
||||||
// domain-specific and is configured to point at the default upstream server
|
|
||||||
// which is validated separately. The upstream is considered domain-specific
|
|
||||||
// only if domains is at least not nil.
|
|
||||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
|
||||||
// The special server address '#' means that default server must be used.
|
|
||||||
if useDefault = u == "#" && domains != nil; useDefault {
|
|
||||||
return useDefault, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the upstream has a valid protocol prefix.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Validate the domain name.
|
|
||||||
for _, proto := range protocols {
|
|
||||||
if strings.HasPrefix(u, proto) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
|
||||||
return false, fmt.Errorf("bad protocol %q", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if upstream is either an IP or IP with port.
|
|
||||||
if _, err = netip.ParseAddr(u); err == nil {
|
|
||||||
return false, nil
|
|
||||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// separateUpstream returns the upstreams and the specified domains. domains
|
|
||||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
|
||||||
// empty.
|
|
||||||
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
|
||||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
|
||||||
return []string{upstreamStr}, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
|
||||||
|
|
||||||
parts := strings.Split(upstreamStr[2:], "/]")
|
|
||||||
switch len(parts) {
|
|
||||||
case 2:
|
|
||||||
// Go on.
|
|
||||||
case 1:
|
|
||||||
return nil, nil, errors.Error("missing separator")
|
|
||||||
default:
|
|
||||||
return nil, nil, errors.Error("duplicated separator")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, host := range strings.Split(parts[0], "/") {
|
|
||||||
if host == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
domains = append(domains, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Fields(parts[1]), domains, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
|
||||||
// properly.
|
|
||||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
|
||||||
|
|
||||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
|
||||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
|
||||||
// testTLD is the special-use fully-qualified domain name for testing the
|
|
||||||
// DNS server reachability.
|
|
||||||
//
|
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
|
||||||
const testTLD = "test."
|
|
||||||
|
|
||||||
req := &dns.Msg{
|
|
||||||
MsgHdr: dns.MsgHdr{
|
|
||||||
Id: dns.Id(),
|
|
||||||
RecursionDesired: true,
|
|
||||||
},
|
|
||||||
Question: []dns.Question{{
|
|
||||||
Name: testTLD,
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply *dns.Msg
|
|
||||||
reply, err = u.Exchange(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
|
||||||
} else if len(reply.Answer) != 0 {
|
|
||||||
return errors.Error("wrong response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
|
||||||
// addresses exchanges correctly.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
|
||||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
|
||||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
|
||||||
// address resolution.
|
|
||||||
//
|
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
|
||||||
const inAddrArpaTLD = "in-addr.arpa."
|
|
||||||
|
|
||||||
req := &dns.Msg{
|
|
||||||
MsgHdr: dns.MsgHdr{
|
|
||||||
Id: dns.Id(),
|
|
||||||
RecursionDesired: true,
|
|
||||||
},
|
|
||||||
Question: []dns.Question{{
|
|
||||||
Name: inAddrArpaTLD,
|
|
||||||
Qtype: dns.TypePTR,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = u.Exchange(req); err != nil {
|
|
||||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
|
||||||
// the tested upstream domain-specific and therefore consider its errors
|
|
||||||
// non-critical.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
|
||||||
// warnings (non-critical errors) is desired.
|
|
||||||
type domainSpecificTestError struct {
|
|
||||||
error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements the [error] interface for domainSpecificTestError.
|
|
||||||
func (err domainSpecificTestError) Error() (msg string) {
|
|
||||||
return fmt.Sprintf("WARNING: %s", err.error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
|
||||||
// upstreams are exchanging correctly. It saves the result into a sync.Map
|
|
||||||
// where key is an upstream address and value is "OK", if the upstream
|
|
||||||
// exchanges correctly, or text of the error. It is intended to be used as a
|
|
||||||
// goroutine.
|
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Separate to a different structure/file.
|
|
||||||
func (s *Server) checkDNS(
|
|
||||||
line string,
|
|
||||||
opts *upstream.Options,
|
|
||||||
check healthCheckFunc,
|
|
||||||
wg *sync.WaitGroup,
|
|
||||||
m *sync.Map,
|
|
||||||
) {
|
|
||||||
defer wg.Done()
|
|
||||||
defer log.OnPanic("dnsforward: checking upstreams")
|
|
||||||
|
|
||||||
upstreams, domains, err := separateUpstream(line)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
|
||||||
m.Store(line, err.Error())
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
specific := len(domains) > 0
|
|
||||||
|
|
||||||
for _, upstreamAddr := range upstreams {
|
|
||||||
var useDefault bool
|
|
||||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
|
||||||
m.Store(upstreamAddr, err.Error())
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if useDefault {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
|
||||||
|
|
||||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
|
||||||
if err != nil {
|
|
||||||
m.Store(upstreamAddr, err.Error())
|
|
||||||
} else {
|
|
||||||
m.Store(upstreamAddr, "OK")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
|
||||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
|
||||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
|
||||||
// is not exchanging correctly.
|
|
||||||
func (s *Server) checkUpstreamAddr(
|
|
||||||
addr string,
|
|
||||||
specific bool,
|
|
||||||
opts *upstream.Options,
|
|
||||||
check healthCheckFunc,
|
|
||||||
) (err error) {
|
|
||||||
defer func() {
|
|
||||||
if err != nil && specific {
|
|
||||||
err = domainSpecificTestError{error: err}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
u, err := upstream.AddressToUpstream(addr, &upstream.Options{
|
|
||||||
Bootstrap: opts.Bootstrap,
|
|
||||||
Timeout: opts.Timeout,
|
|
||||||
PreferIPv6: opts.PreferIPv6,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
|
||||||
|
|
||||||
return check(u)
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeBoots closes all the provided bootstrap servers and logs errors if any.
|
// closeBoots closes all the provided bootstrap servers and logs errors if any.
|
||||||
func closeBoots(boots []*upstream.UpstreamResolver) {
|
func closeBoots(boots []*upstream.UpstreamResolver) {
|
||||||
for _, c := range boots {
|
for _, c := range boots {
|
||||||
|
@ -942,36 +577,11 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
defer closeBoots(boots)
|
defer closeBoots(boots)
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
|
||||||
m := &sync.Map{}
|
cv.check()
|
||||||
|
cv.close()
|
||||||
|
|
||||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
aghhttp.WriteJSONResponseOK(w, r, cv.status())
|
||||||
|
|
||||||
for _, ups := range req.Upstreams {
|
|
||||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
|
||||||
}
|
|
||||||
for _, ups := range req.FallbackDNS {
|
|
||||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
|
||||||
}
|
|
||||||
for _, ups := range req.PrivateUpstreams {
|
|
||||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
result := map[string]string{}
|
|
||||||
m.Range(func(k, v any) bool {
|
|
||||||
// TODO(e.burkov): The upstreams used for both common and private
|
|
||||||
// resolving should be reported separately.
|
|
||||||
ups := k.(string)
|
|
||||||
status := v.(string)
|
|
||||||
|
|
||||||
result[ups] = status
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||||
|
|
|
@ -363,7 +363,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||||
set: []string{"123.3.7m"},
|
set: []string{"123.3.7m"},
|
||||||
}, {
|
}, {
|
||||||
name: "invalid",
|
name: "invalid",
|
||||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
|
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
|
||||||
`missing separator`,
|
`missing separator`,
|
||||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||||
}, {
|
}, {
|
||||||
|
@ -389,7 +389,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
name: "bad_domain",
|
name: "bad_domain",
|
||||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
|
||||||
`bad domain name "!": bad top-level domain name label "!": ` +
|
`bad domain name "!": bad top-level domain name label "!": ` +
|
||||||
`bad top-level domain name label rune '!'`,
|
`bad top-level domain name label rune '!'`,
|
||||||
set: []string{"[/!/]8.8.8.8"},
|
set: []string{"[/!/]8.8.8.8"},
|
||||||
|
@ -477,25 +477,15 @@ func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (r
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||||
require.NoError(testutil.PanicT{}, err)
|
require.NoError(testutil.PanicT{}, err)
|
||||||
})
|
})
|
||||||
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
|
|
||||||
err := w.WriteMsg(new(dns.Msg))
|
|
||||||
require.NoError(testutil.PanicT{}, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
goodUps := (&url.URL{
|
ups := (&url.URL{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
|
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
|
||||||
}).String()
|
}).String()
|
||||||
badUps := (&url.URL{
|
|
||||||
Scheme: "tcp",
|
|
||||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
|
||||||
}).String()
|
|
||||||
|
|
||||||
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
upsTimeout = 100 * time.Millisecond
|
upsTimeout = 100 * time.Millisecond
|
||||||
|
@ -504,7 +494,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
upstreamHost = "custom.localhost"
|
upstreamHost = "custom.localhost"
|
||||||
)
|
)
|
||||||
|
|
||||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
hostsListener := newLocalUpstreamListener(t, 0, hdlr)
|
||||||
hostsUps := (&url.URL{
|
hostsUps := (&url.URL{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
||||||
|
@ -545,43 +535,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
wantResp map[string]any
|
wantResp map[string]any
|
||||||
name string
|
name string
|
||||||
}{{
|
}{{
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{goodUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
goodUps: "OK",
|
|
||||||
},
|
|
||||||
name: "success",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{badUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
|
||||||
badUps + ` over tcp: dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "broken",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{goodUps, badUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
goodUps: "OK",
|
|
||||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
|
||||||
badUps + ` over tcp: dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "both",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
badUps: `WARNING: couldn't communicate ` +
|
|
||||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
|
||||||
`dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "domain_specific_error",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
body: map[string]any{
|
||||||
"upstream_dns": []string{hostsUps},
|
"upstream_dns": []string{hostsUps},
|
||||||
},
|
},
|
||||||
|
@ -591,63 +544,12 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
name: "etc_hosts",
|
name: "etc_hosts",
|
||||||
}, {
|
}, {
|
||||||
body: map[string]any{
|
body: map[string]any{
|
||||||
"fallback_dns": []string{goodUps},
|
"upstream_dns": []string{ups, "#this.is.comment"},
|
||||||
},
|
},
|
||||||
wantResp: map[string]any{
|
wantResp: map[string]any{
|
||||||
goodUps: "OK",
|
ups: "OK",
|
||||||
},
|
},
|
||||||
name: "fallback_success",
|
name: "comment_mix",
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"fallback_dns": []string{badUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
|
||||||
badUps + ` over tcp: dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "fallback_broken",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"fallback_dns": []string{goodUps, "#this.is.comment"},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
goodUps: "OK",
|
|
||||||
},
|
|
||||||
name: "fallback_comment_mix",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
goodUps: "OK",
|
|
||||||
badUps: `WARNING: couldn't communicate ` +
|
|
||||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
|
||||||
`dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "multiple_domain_specific_upstreams",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
|
|
||||||
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
|
|
||||||
`duplicated separator`,
|
|
||||||
},
|
|
||||||
name: "bad_specification",
|
|
||||||
}, {
|
|
||||||
body: map[string]any{
|
|
||||||
"upstream_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
|
||||||
"fallback_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
|
||||||
"private_upstream": []string{"[/domain.example/]" + goodAndBadUps},
|
|
||||||
},
|
|
||||||
wantResp: map[string]any{
|
|
||||||
goodUps: "OK",
|
|
||||||
badUps: `WARNING: couldn't communicate ` +
|
|
||||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
|
||||||
`dns: id mismatch`,
|
|
||||||
},
|
|
||||||
name: "all_different",
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
|
|
@ -2,14 +2,43 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// errNotDomainSpecific is returned when the upstream should be
|
||||||
|
// domain-specific, but isn't.
|
||||||
|
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
||||||
|
|
||||||
|
// errMissingSeparator is returned when the domain-specific part of the
|
||||||
|
// upstream configuration line isn't closed.
|
||||||
|
errMissingSeparator errors.Error = "missing separator"
|
||||||
|
|
||||||
|
// errDupSeparator is returned when the domain-specific part of the upstream
|
||||||
|
// configuration line contains more than one ending separator.
|
||||||
|
errDupSeparator errors.Error = "duplicated separator"
|
||||||
|
|
||||||
|
// errNoDefaultUpstreams is returned when there are no default upstreams
|
||||||
|
// specified in the upstream configuration.
|
||||||
|
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
||||||
|
|
||||||
|
// errWrongResponse is returned when the checked upstream replies in an
|
||||||
|
// unexpected way.
|
||||||
|
errWrongResponse errors.Error = "wrong response"
|
||||||
)
|
)
|
||||||
|
|
||||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||||
|
@ -158,3 +187,183 @@ func (s *Server) createBootstrap(
|
||||||
|
|
||||||
return r, boots, nil
|
return r, boots, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||||
|
// This function is useful for filtering out non-upstream lines from upstream
|
||||||
|
// configs.
|
||||||
|
func IsCommentOrEmpty(s string) (ok bool) {
|
||||||
|
return len(s) == 0 || s[0] == '#'
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||||
|
// configuration or nil if it can't be built.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||||
|
// slice already so that this function may be considered useless.
|
||||||
|
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||||
|
// No need to validate comments and empty lines.
|
||||||
|
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
// Consider this case valid since it means the default server should be
|
||||||
|
// used.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateUpstreamConfig(upstreams)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conf, err = proxy.ParseUpstreamsConfig(
|
||||||
|
upstreams,
|
||||||
|
&upstream.Options{
|
||||||
|
Bootstrap: net.DefaultResolver,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return nil, err
|
||||||
|
} else if len(conf.Upstreams) == 0 {
|
||||||
|
return nil, errNoDefaultUpstreams
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateUpstreamConfig validates each upstream from the upstream
|
||||||
|
// configuration and returns an error if any upstream is invalid.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||||
|
func validateUpstreamConfig(conf []string) (err error) {
|
||||||
|
for _, u := range conf {
|
||||||
|
var ups []string
|
||||||
|
var isSpecific bool
|
||||||
|
ups, isSpecific, err = splitUpstreamLine(u)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range ups {
|
||||||
|
_, err = validateUpstream(addr, isSpecific)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUpstreams validates each upstream and returns an error if any
|
||||||
|
// upstream is invalid or if there are no default upstreams specified.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||||
|
func ValidateUpstreams(upstreams []string) (err error) {
|
||||||
|
_, err = newUpstreamConfig(upstreams)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||||
|
// upstream is invalid or if there are no default upstreams specified. It also
|
||||||
|
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||||
|
// a locally-served network. privateNets must not be nil.
|
||||||
|
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||||
|
conf, err := newUpstreamConfig(upstreams)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||||
|
slices.Sort(keys)
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, domain := range keys {
|
||||||
|
var subnet netip.Prefix
|
||||||
|
subnet, err = extractARPASubnet(domain)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
||||||
|
errs = append(
|
||||||
|
errs,
|
||||||
|
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||||
|
}
|
||||||
|
|
||||||
|
// protocols are the supported URL schemes for upstreams.
|
||||||
|
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
|
||||||
|
|
||||||
|
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||||
|
// upstream configuration. useDefault is true if the upstream is
|
||||||
|
// domain-specific and is configured to point at the default upstream server
|
||||||
|
// which is validated separately. The upstream is considered domain-specific
|
||||||
|
// only if domains is at least not nil.
|
||||||
|
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
|
||||||
|
// The special server address '#' means that default server must be used.
|
||||||
|
if useDefault = u == "#" && isSpecific; useDefault {
|
||||||
|
return useDefault, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the upstream has a valid protocol prefix.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Validate the domain name.
|
||||||
|
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||||
|
if !slices.Contains(protocols, proto) {
|
||||||
|
return false, fmt.Errorf("bad protocol %q", proto)
|
||||||
|
}
|
||||||
|
} else if _, err = netip.ParseAddr(u); err == nil {
|
||||||
|
return false, nil
|
||||||
|
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitUpstreamLine returns the upstreams and the specified domains. domains
|
||||||
|
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||||
|
// empty.
|
||||||
|
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
|
||||||
|
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||||
|
return []string{upstreamStr}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
|
||||||
|
|
||||||
|
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
|
||||||
|
if !found {
|
||||||
|
return nil, false, errMissingSeparator
|
||||||
|
} else if strings.Contains(ups, "/]") {
|
||||||
|
return nil, false, errDupSeparator
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, host := range strings.Split(doms, "/") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isSpecific = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Fields(ups), isSpecific, nil
|
||||||
|
}
|
||||||
|
|
209
internal/dnsforward/upstreams_internal_test.go
Normal file
209
internal/dnsforward/upstreams_internal_test.go
Normal file
|
@ -0,0 +1,209 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpstreamConfigValidator(t *testing.T) {
|
||||||
|
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||||
|
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||||
|
require.NoError(testutil.PanicT{}, err)
|
||||||
|
})
|
||||||
|
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
|
||||||
|
err := w.WriteMsg(new(dns.Msg))
|
||||||
|
require.NoError(testutil.PanicT{}, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
goodUps := (&url.URL{
|
||||||
|
Scheme: "tcp",
|
||||||
|
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
|
||||||
|
}).String()
|
||||||
|
badUps := (&url.URL{
|
||||||
|
Scheme: "tcp",
|
||||||
|
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
||||||
|
|
||||||
|
// upsTimeout restricts the checking process to prevent the test from
|
||||||
|
// hanging.
|
||||||
|
const upsTimeout = 100 * time.Millisecond
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want map[string]string
|
||||||
|
name string
|
||||||
|
general []string
|
||||||
|
fallback []string
|
||||||
|
private []string
|
||||||
|
}{{
|
||||||
|
name: "success",
|
||||||
|
general: []string{goodUps},
|
||||||
|
want: map[string]string{
|
||||||
|
goodUps: "OK",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "broken",
|
||||||
|
general: []string{badUps},
|
||||||
|
want: map[string]string{
|
||||||
|
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||||
|
badUps + ` over tcp: dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "both",
|
||||||
|
general: []string{goodUps, badUps, goodUps},
|
||||||
|
want: map[string]string{
|
||||||
|
goodUps: "OK",
|
||||||
|
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||||
|
badUps + ` over tcp: dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "domain_specific_error",
|
||||||
|
general: []string{"[/domain.example/]" + badUps},
|
||||||
|
want: map[string]string{
|
||||||
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
|
`dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "fallback_success",
|
||||||
|
fallback: []string{goodUps},
|
||||||
|
want: map[string]string{
|
||||||
|
goodUps: "OK",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "fallback_broken",
|
||||||
|
fallback: []string{badUps},
|
||||||
|
want: map[string]string{
|
||||||
|
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||||
|
badUps + ` over tcp: dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "multiple_domain_specific_upstreams",
|
||||||
|
general: []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
want: map[string]string{
|
||||||
|
goodUps: "OK",
|
||||||
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
|
`dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "bad_specification",
|
||||||
|
general: []string{"[/domain.example/]/]1.2.3.4"},
|
||||||
|
want: map[string]string{
|
||||||
|
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
|
||||||
|
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "all_different",
|
||||||
|
general: []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
fallback: []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
private: []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
want: map[string]string{
|
||||||
|
goodUps: "OK",
|
||||||
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
|
`dns: id mismatch`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "bad_specific_domains",
|
||||||
|
general: []string{"[/example/]/]" + goodUps},
|
||||||
|
fallback: []string{"[/example/" + goodUps},
|
||||||
|
private: []string{"[/example//bad.123/]" + goodUps},
|
||||||
|
want: map[string]string{
|
||||||
|
`[/example/]/]` + goodUps: `splitting upstream line ` +
|
||||||
|
`"[/example/]/]` + goodUps + `": duplicated separator`,
|
||||||
|
`[/example/` + goodUps: `splitting upstream line ` +
|
||||||
|
`"[/example/` + goodUps + `": missing separator`,
|
||||||
|
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
|
||||||
|
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
|
||||||
|
`bad domain name "bad.123": ` +
|
||||||
|
`bad top-level domain name label "123": all octets are numeric`,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "non-specific_default",
|
||||||
|
general: []string{
|
||||||
|
"#",
|
||||||
|
"[/example/]#",
|
||||||
|
},
|
||||||
|
want: map[string]string{
|
||||||
|
"#": "not a domain-specific upstream",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "bad_proto",
|
||||||
|
general: []string{
|
||||||
|
"bad://1.2.3.4",
|
||||||
|
},
|
||||||
|
want: map[string]string{
|
||||||
|
"bad://1.2.3.4": `bad protocol "bad"`,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
cv := newUpstreamConfigValidator(tc.general, tc.fallback, tc.private, &upstream.Options{
|
||||||
|
Timeout: upsTimeout,
|
||||||
|
Bootstrap: net.DefaultResolver,
|
||||||
|
})
|
||||||
|
cv.check()
|
||||||
|
cv.close()
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, cv.status())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpstreamConfigValidator_Check_once(t *testing.T) {
|
||||||
|
reqs := atomic.Int32{}
|
||||||
|
reset := func() { reqs.Store(0) }
|
||||||
|
|
||||||
|
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||||
|
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||||
|
require.NoError(testutil.PanicT{}, err)
|
||||||
|
reqs.Add(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := (&url.URL{
|
||||||
|
Scheme: "tcp",
|
||||||
|
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
|
||||||
|
}).String()
|
||||||
|
twoAddrs := strings.Join([]string{addr, addr}, " ")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
ups []string
|
||||||
|
}{{
|
||||||
|
name: "common",
|
||||||
|
ups: []string{addr, addr, addr},
|
||||||
|
}, {
|
||||||
|
name: "domain-specific",
|
||||||
|
ups: []string{"[/one.example/]" + addr, "[/two.example/]" + twoAddrs},
|
||||||
|
}, {
|
||||||
|
name: "both",
|
||||||
|
ups: []string{addr, "[/one.example/]" + addr, addr, "[/two.example/]" + twoAddrs},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Cleanup(reset)
|
||||||
|
|
||||||
|
cv := newUpstreamConfigValidator(tc.ups, nil, nil, &upstream.Options{
|
||||||
|
Timeout: 100 * time.Millisecond,
|
||||||
|
})
|
||||||
|
cv.check()
|
||||||
|
cv.close()
|
||||||
|
|
||||||
|
assert.Equal(t, int32(1), reqs.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue