diff --git a/internal/dnsforward/configvalidator.go b/internal/dnsforward/configvalidator.go new file mode 100644 index 00000000..b55f53cb --- /dev/null +++ b/internal/dnsforward/configvalidator.go @@ -0,0 +1,349 @@ +package dnsforward + +import ( + "fmt" + "strings" + "sync" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" + "golang.org/x/exp/slices" +) + +// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the +// actual DNS availability of each upstream. +type upstreamConfigValidator struct { + // general is the general upstream configuration. + general []*upstreamResult + + // fallback is the fallback upstream configuration. + fallback []*upstreamResult + + // private is the private upstream configuration. + private []*upstreamResult +} + +// upstreamResult is a result of validation of an [upstream.Upstream] within an +// [proxy.UpstreamConfig]. +type upstreamResult struct { + // server is the parsed upstream. It is nil when there was an error during + // parsing. + server upstream.Upstream + + // err is the error either from parsing or from checking the upstream. + err error + + // original is the piece of configuration that have either been turned to an + // upstream or caused an error. + original string + + // isSpecific is true if the upstream is domain-specific. + isSpecific bool +} + +// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1 +// if ur should be sorted before other, and 1 otherwise. +// +// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near +// the end. +func (ur *upstreamResult) compare(other *upstreamResult) (res int) { + return strings.Compare(ur.original, other.original) +} + +// newUpstreamConfigValidator parses the upstream configuration and returns a +// validator for it. cv already contains the parsed upstreams along with errors +// related. +func newUpstreamConfigValidator( + general []string, + fallback []string, + private []string, + opts *upstream.Options, +) (cv *upstreamConfigValidator) { + cv = &upstreamConfigValidator{} + + for _, line := range general { + cv.general = cv.insertLineResults(cv.general, line, opts) + } + for _, line := range fallback { + cv.fallback = cv.insertLineResults(cv.fallback, line, opts) + } + for _, line := range private { + cv.private = cv.insertLineResults(cv.private, line, opts) + } + + return cv +} + +// insertLineResults parses line and inserts the result into s. It can insert +// multiple results as well as none. +func (cv *upstreamConfigValidator) insertLineResults( + s []*upstreamResult, + line string, + opts *upstream.Options, +) (result []*upstreamResult) { + upstreams, isSpecific, err := splitUpstreamLine(line) + if err != nil { + return cv.insert(s, &upstreamResult{ + err: err, + original: line, + }) + } + + for _, upstreamAddr := range upstreams { + var res *upstreamResult + if upstreamAddr != "#" { + res = cv.parseUpstream(upstreamAddr, opts) + } else if !isSpecific { + res = &upstreamResult{ + err: errNotDomainSpecific, + original: upstreamAddr, + } + } else { + continue + } + + res.isSpecific = isSpecific + s = cv.insert(s, res) + } + + return s +} + +// insert inserts r into slice in a sorted order, except duplicates. slice must +// not be nil. +func (cv *upstreamConfigValidator) insert( + s []*upstreamResult, + r *upstreamResult, +) (result []*upstreamResult) { + i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare) + if has { + log.Debug("dnsforward: duplicate configuration %q", r.original) + + return s + } + + return slices.Insert(s, i, r) +} + +// parseUpstream parses addr and returns the result of parsing. It returns nil +// if the specified server points at the default upstream server which is +// validated separately. +func (cv *upstreamConfigValidator) parseUpstream( + addr string, + opts *upstream.Options, +) (r *upstreamResult) { + // Check if the upstream has a valid protocol prefix. + // + // TODO(e.burkov): Validate the domain name. + if proto, _, ok := strings.Cut(addr, "://"); ok { + if !slices.Contains(protocols, proto) { + return &upstreamResult{ + err: fmt.Errorf("bad protocol %q", proto), + original: addr, + } + } + } + + ups, err := upstream.AddressToUpstream(addr, opts) + + return &upstreamResult{ + server: ups, + err: err, + original: addr, + } +} + +// check tries to exchange with each successfully parsed upstream and enriches +// the results with the healthcheck errors. It should not be called after the +// [upsConfValidator.close] method, since it makes no sense to check the closed +// upstreams. +func (cv *upstreamConfigValidator) check() { + const ( + // testTLD is the special-use fully-qualified domain name for testing + // the DNS server reachability. + // + // See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2. + testTLD = "test." + + // inAddrARPATLD is the special-use fully-qualified domain name for PTR + // IP address resolution. + // + // See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5. + inAddrARPATLD = "in-addr.arpa." + ) + + commonChecker := &healthchecker{ + hostname: testTLD, + qtype: dns.TypeA, + ansEmpty: true, + } + + arpaChecker := &healthchecker{ + hostname: inAddrARPATLD, + qtype: dns.TypePTR, + ansEmpty: false, + } + + wg := &sync.WaitGroup{} + wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private)) + + for _, res := range cv.general { + go cv.checkSrv(res, wg, commonChecker) + } + for _, res := range cv.fallback { + go cv.checkSrv(res, wg, commonChecker) + } + for _, res := range cv.private { + go cv.checkSrv(res, wg, arpaChecker) + } + + wg.Wait() +} + +// checkSrv runs hc on the server from res, if any, and stores any occurred +// error in res. wg is always marked done in the end. It used to be called in +// a separate goroutine. +func (cv *upstreamConfigValidator) checkSrv( + res *upstreamResult, + wg *sync.WaitGroup, + hc *healthchecker, +) { + defer wg.Done() + + if res.server == nil { + return + } + + res.err = hc.check(res.server) + if res.err != nil && res.isSpecific { + res.err = domainSpecificTestError{Err: res.err} + } +} + +// close closes all the upstreams that were successfully parsed. It enriches +// the results with deferred closing errors. +func (cv *upstreamConfigValidator) close() { + for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} { + for _, r := range slice { + if r.server != nil { + r.err = errors.WithDeferred(r.err, r.server.Close()) + } + } + } +} + +// status returns all the data collected during parsing, healthcheck, and +// closing of the upstreams. The returned map is keyed by the original upstream +// configuration piece and contains the corresponding error or "OK" if there was +// no error. +func (cv *upstreamConfigValidator) status() (results map[string]string) { + result := map[string]string{} + + for _, res := range cv.general { + resultToStatus("general", res, result) + } + for _, res := range cv.fallback { + resultToStatus("fallback", res, result) + } + for _, res := range cv.private { + resultToStatus("private", res, result) + } + + return result +} + +// resultToStatus puts "OK" or an error message from res into resMap. section +// is the name of the upstream configuration section, i.e. "general", +// "fallback", or "private", and only used for logging. +// +// TODO(e.burkov): Currently, the HTTP handler expects that all the results are +// put together in a single map, which may lead to collisions, see AG-27539. +// Improve the results compilation. +func resultToStatus(section string, res *upstreamResult, resMap map[string]string) { + val := "OK" + if res.err != nil { + val = res.err.Error() + } + + prevVal := resMap[res.original] + switch prevVal { + case "": + resMap[res.original] = val + case val: + log.Debug("dnsforward: duplicating %s config line %q", section, res.original) + default: + log.Debug( + "dnsforward: warning: %s config line %q (%v) had different result %v", + section, + val, + res.original, + prevVal, + ) + } +} + +// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark +// the tested upstream domain-specific and therefore consider its errors +// non-critical. +// +// TODO(a.garipov): Some common mechanism of distinguishing between errors and +// warnings (non-critical errors) is desired. +type domainSpecificTestError struct { + // Err is the actual error occurred during healthcheck test. + Err error +} + +// type check +var _ error = domainSpecificTestError{} + +// Error implements the [error] interface for domainSpecificTestError. +func (err domainSpecificTestError) Error() (msg string) { + return fmt.Sprintf("WARNING: %s", err.Err) +} + +// type check +var _ errors.Wrapper = domainSpecificTestError{} + +// Unwrap implements the [errors.Wrapper] interface for domainSpecificTestError. +func (err domainSpecificTestError) Unwrap() (wrapped error) { + return err.Err +} + +// healthchecker checks the upstream's status by exchanging with it. +type healthchecker struct { + // hostname is the name of the host to put into healthcheck DNS request. + hostname string + + // qtype is the type of DNS request to use for healthcheck. + qtype uint16 + + // ansEmpty defines if the answer section within the response is expected to + // be empty. + ansEmpty bool +} + +// check exchanges with u and validates the response. +func (h *healthchecker) check(u upstream.Upstream) (err error) { + req := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: h.hostname, + Qtype: h.qtype, + Qclass: dns.ClassINET, + }}, + } + + reply, err := u.Exchange(req) + if err != nil { + return fmt.Errorf("couldn't communicate with upstream: %w", err) + } else if h.ansEmpty && len(reply.Answer) > 0 { + return errWrongResponse + } + + return nil +} diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index d3bca111..d21346a7 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -35,6 +35,11 @@ import ( // DefaultTimeout is the default upstream timeout const DefaultTimeout = 10 * time.Second +// defaultLocalTimeout is the default timeout for resolving addresses from +// locally-served networks. It is assumed that local resolvers should work much +// faster than ordinary upstreams. +const defaultLocalTimeout = 1 * time.Second + // defaultClientIDCacheCount is the default count of items in the LRU ClientID // cache. The assumption here is that there won't be more than this many // requests between the BeforeRequestHandler stage and the actual processing. @@ -459,11 +464,6 @@ func (s *Server) startLocked() error { return err } -// defaultLocalTimeout is the default timeout for resolving addresses from -// locally-served networks. It is assumed that local resolvers should work much -// faster than ordinary upstreams. -const defaultLocalTimeout = 1 * time.Second - // setupLocalResolvers initializes the resolvers for local addresses. It // assumes s.serverLock is locked or the Server not running. func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) { diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 53874578..9544e22c 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -4,23 +4,17 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/netip" - "strings" - "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" - "github.com/miekg/dns" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -546,365 +540,6 @@ type upstreamJSON struct { PrivateUpstreams []string `json:"private_upstream"` } -// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. -// This function is useful for filtering out non-upstream lines from upstream -// configs. -func IsCommentOrEmpty(s string) (ok bool) { - return len(s) == 0 || s[0] == '#' -} - -// newUpstreamConfig validates upstreams and returns an appropriate upstream -// configuration or nil if it can't be built. -// -// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams -// slice already so that this function may be considered useless. -func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) { - // No need to validate comments and empty lines. - upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty) - if len(upstreams) == 0 { - // Consider this case valid since it means the default server should be - // used. - return nil, nil - } - - err = validateUpstreamConfig(upstreams) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return nil, err - } - - conf, err = proxy.ParseUpstreamsConfig( - upstreams, - &upstream.Options{ - Bootstrap: net.DefaultResolver, - Timeout: DefaultTimeout, - }, - ) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return nil, err - } else if len(conf.Upstreams) == 0 { - return nil, errors.Error("no default upstreams specified") - } - - return conf, nil -} - -// validateUpstreamConfig validates each upstream from the upstream -// configuration and returns an error if any upstream is invalid. -// -// TODO(e.burkov): Move into aghnet or even into dnsproxy. -func validateUpstreamConfig(conf []string) (err error) { - for _, u := range conf { - var ups []string - var domains []string - ups, domains, err = separateUpstream(u) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err - } - - for _, addr := range ups { - _, err = validateUpstream(addr, domains) - if err != nil { - return fmt.Errorf("validating upstream %q: %w", addr, err) - } - } - } - - return nil -} - -// ValidateUpstreams validates each upstream and returns an error if any -// upstream is invalid or if there are no default upstreams specified. -// -// TODO(e.burkov): Move into aghnet or even into dnsproxy. -func ValidateUpstreams(upstreams []string) (err error) { - _, err = newUpstreamConfig(upstreams) - - return err -} - -// ValidateUpstreamsPrivate validates each upstream and returns an error if any -// upstream is invalid or if there are no default upstreams specified. It also -// checks each domain of domain-specific upstreams for being ARPA pointing to -// a locally-served network. privateNets must not be nil. -func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { - conf, err := newUpstreamConfig(upstreams) - if err != nil { - return fmt.Errorf("creating config: %w", err) - } - - if conf == nil { - return nil - } - - keys := maps.Keys(conf.DomainReservedUpstreams) - slices.Sort(keys) - - var errs []error - for _, domain := range keys { - var subnet netip.Prefix - subnet, err = extractARPASubnet(domain) - if err != nil { - errs = append(errs, err) - - continue - } - - if !privateNets.Contains(subnet.Addr().AsSlice()) { - errs = append( - errs, - fmt.Errorf("arpa domain %q should point to a locally-served network", domain), - ) - } - } - - return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w") -} - -var protocols = []string{ - "h3://", - "https://", - "quic://", - "sdns://", - "tcp://", - "tls://", - "udp://", -} - -// validateUpstream returns an error if u alongside with domains is not a valid -// upstream configuration. useDefault is true if the upstream is -// domain-specific and is configured to point at the default upstream server -// which is validated separately. The upstream is considered domain-specific -// only if domains is at least not nil. -func validateUpstream(u string, domains []string) (useDefault bool, err error) { - // The special server address '#' means that default server must be used. - if useDefault = u == "#" && domains != nil; useDefault { - return useDefault, nil - } - - // Check if the upstream has a valid protocol prefix. - // - // TODO(e.burkov): Validate the domain name. - for _, proto := range protocols { - if strings.HasPrefix(u, proto) { - return false, nil - } - } - - if proto, _, ok := strings.Cut(u, "://"); ok { - return false, fmt.Errorf("bad protocol %q", proto) - } - - // Check if upstream is either an IP or IP with port. - if _, err = netip.ParseAddr(u); err == nil { - return false, nil - } else if _, err = netip.ParseAddrPort(u); err == nil { - return false, nil - } - - return false, err -} - -// separateUpstream returns the upstreams and the specified domains. domains -// is nil when the upstream is not domains-specific. Otherwise it may also be -// empty. -func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) { - if !strings.HasPrefix(upstreamStr, "[/") { - return []string{upstreamStr}, nil, nil - } - - defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }() - - parts := strings.Split(upstreamStr[2:], "/]") - switch len(parts) { - case 2: - // Go on. - case 1: - return nil, nil, errors.Error("missing separator") - default: - return nil, nil, errors.Error("duplicated separator") - } - - for i, host := range strings.Split(parts[0], "/") { - if host == "" { - continue - } - - err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*.")) - if err != nil { - return nil, nil, fmt.Errorf("domain at index %d: %w", i, err) - } - - domains = append(domains, host) - } - - return strings.Fields(parts[1]), domains, nil -} - -// healthCheckFunc is a signature of function to check if upstream exchanges -// properly. -type healthCheckFunc func(u upstream.Upstream) (err error) - -// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly. -func checkDNSUpstreamExc(u upstream.Upstream) (err error) { - // testTLD is the special-use fully-qualified domain name for testing the - // DNS server reachability. - // - // See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2. - const testTLD = "test." - - req := &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - RecursionDesired: true, - }, - Question: []dns.Question{{ - Name: testTLD, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - var reply *dns.Msg - reply, err = u.Exchange(req) - if err != nil { - return fmt.Errorf("couldn't communicate with upstream: %w", err) - } else if len(reply.Answer) != 0 { - return errors.Error("wrong response") - } - - return nil -} - -// checkPrivateUpstreamExc checks if the upstream for resolving private -// addresses exchanges correctly. -// -// TODO(e.burkov): Think about testing the ip6.arpa. as well. -func checkPrivateUpstreamExc(u upstream.Upstream) (err error) { - // inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP - // address resolution. - // - // See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5. - const inAddrArpaTLD = "in-addr.arpa." - - req := &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - RecursionDesired: true, - }, - Question: []dns.Question{{ - Name: inAddrArpaTLD, - Qtype: dns.TypePTR, - Qclass: dns.ClassINET, - }}, - } - - if _, err = u.Exchange(req); err != nil { - return fmt.Errorf("couldn't communicate with upstream: %w", err) - } - - return nil -} - -// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark -// the tested upstream domain-specific and therefore consider its errors -// non-critical. -// -// TODO(a.garipov): Some common mechanism of distinguishing between errors and -// warnings (non-critical errors) is desired. -type domainSpecificTestError struct { - error -} - -// Error implements the [error] interface for domainSpecificTestError. -func (err domainSpecificTestError) Error() (msg string) { - return fmt.Sprintf("WARNING: %s", err.error) -} - -// checkDNS parses line, creates DNS upstreams using opts, and checks if the -// upstreams are exchanging correctly. It saves the result into a sync.Map -// where key is an upstream address and value is "OK", if the upstream -// exchanges correctly, or text of the error. It is intended to be used as a -// goroutine. -// -// TODO(s.chzhen): Separate to a different structure/file. -func (s *Server) checkDNS( - line string, - opts *upstream.Options, - check healthCheckFunc, - wg *sync.WaitGroup, - m *sync.Map, -) { - defer wg.Done() - defer log.OnPanic("dnsforward: checking upstreams") - - upstreams, domains, err := separateUpstream(line) - if err != nil { - err = fmt.Errorf("wrong upstream format: %w", err) - m.Store(line, err.Error()) - - return - } - - specific := len(domains) > 0 - - for _, upstreamAddr := range upstreams { - var useDefault bool - useDefault, err = validateUpstream(upstreamAddr, domains) - if err != nil { - err = fmt.Errorf("wrong upstream format: %w", err) - m.Store(upstreamAddr, err.Error()) - - continue - } - - if useDefault { - continue - } - - log.Debug("dnsforward: checking if upstream %q works", upstreamAddr) - - err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check) - if err != nil { - m.Store(upstreamAddr, err.Error()) - } else { - m.Store(upstreamAddr, "OK") - } - } -} - -// checkUpstreamAddr creates the DNS upstream using opts and information from -// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It -// returns an error if addr is not valid DNS upstream address or the upstream -// is not exchanging correctly. -func (s *Server) checkUpstreamAddr( - addr string, - specific bool, - opts *upstream.Options, - check healthCheckFunc, -) (err error) { - defer func() { - if err != nil && specific { - err = domainSpecificTestError{error: err} - } - }() - - u, err := upstream.AddressToUpstream(addr, &upstream.Options{ - Bootstrap: opts.Bootstrap, - Timeout: opts.Timeout, - PreferIPv6: opts.PreferIPv6, - }) - if err != nil { - return fmt.Errorf("creating upstream for %q: %w", addr, err) - } - - defer func() { err = errors.WithDeferred(err, u.Close()) }() - - return check(u) -} - // closeBoots closes all the provided bootstrap servers and logs errors if any. func closeBoots(boots []*upstream.UpstreamResolver) { for _, c := range boots { @@ -942,36 +577,11 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } defer closeBoots(boots) - wg := &sync.WaitGroup{} - m := &sync.Map{} + cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts) + cv.check() + cv.close() - wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)) - - for _, ups := range req.Upstreams { - go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m) - } - for _, ups := range req.FallbackDNS { - go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m) - } - for _, ups := range req.PrivateUpstreams { - go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m) - } - - wg.Wait() - - result := map[string]string{} - m.Range(func(k, v any) bool { - // TODO(e.burkov): The upstreams used for both common and private - // resolving should be reported separately. - ups := k.(string) - status := v.(string) - - result[ups] = status - - return true - }) - - aghhttp.WriteJSONResponseOK(w, r, result) + aghhttp.WriteJSONResponseOK(w, r, cv.status()) } // handleCacheClear is the handler for the POST /control/cache_clear HTTP API. diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index b16c26df..005f08f5 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -363,7 +363,7 @@ func TestValidateUpstreams(t *testing.T) { set: []string{"123.3.7m"}, }, { name: "invalid", - wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` + + wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` + `missing separator`, set: []string{"[/host.com]tls://dns.adguard.com"}, }, { @@ -389,7 +389,7 @@ func TestValidateUpstreams(t *testing.T) { }, }, { name: "bad_domain", - wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` + + wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` + `bad domain name "!": bad top-level domain name label "!": ` + `bad top-level domain name label rune '!'`, set: []string{"[/!/]8.8.8.8"}, @@ -477,25 +477,15 @@ func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (r } func TestServer_HandleTestUpstreamDNS(t *testing.T) { - goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { err := w.WriteMsg(new(dns.Msg).SetReply(m)) require.NoError(testutil.PanicT{}, err) }) - badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) { - err := w.WriteMsg(new(dns.Msg)) - require.NoError(testutil.PanicT{}, err) - }) - goodUps := (&url.URL{ + ups := (&url.URL{ Scheme: "tcp", - Host: newLocalUpstreamListener(t, 0, goodHandler).String(), + Host: newLocalUpstreamListener(t, 0, hdlr).String(), }).String() - badUps := (&url.URL{ - Scheme: "tcp", - Host: newLocalUpstreamListener(t, 0, badHandler).String(), - }).String() - - goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ") const ( upsTimeout = 100 * time.Millisecond @@ -504,7 +494,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { upstreamHost = "custom.localhost" ) - hostsListener := newLocalUpstreamListener(t, 0, goodHandler) + hostsListener := newLocalUpstreamListener(t, 0, hdlr) hostsUps := (&url.URL{ Scheme: "tcp", Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()), @@ -545,43 +535,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { wantResp map[string]any name string }{{ - body: map[string]any{ - "upstream_dns": []string{goodUps}, - }, - wantResp: map[string]any{ - goodUps: "OK", - }, - name: "success", - }, { - body: map[string]any{ - "upstream_dns": []string{badUps}, - }, - wantResp: map[string]any{ - badUps: `couldn't communicate with upstream: exchanging with ` + - badUps + ` over tcp: dns: id mismatch`, - }, - name: "broken", - }, { - body: map[string]any{ - "upstream_dns": []string{goodUps, badUps}, - }, - wantResp: map[string]any{ - goodUps: "OK", - badUps: `couldn't communicate with upstream: exchanging with ` + - badUps + ` over tcp: dns: id mismatch`, - }, - name: "both", - }, { - body: map[string]any{ - "upstream_dns": []string{"[/domain.example/]" + badUps}, - }, - wantResp: map[string]any{ - badUps: `WARNING: couldn't communicate ` + - `with upstream: exchanging with ` + badUps + ` over tcp: ` + - `dns: id mismatch`, - }, - name: "domain_specific_error", - }, { body: map[string]any{ "upstream_dns": []string{hostsUps}, }, @@ -591,63 +544,12 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { name: "etc_hosts", }, { body: map[string]any{ - "fallback_dns": []string{goodUps}, + "upstream_dns": []string{ups, "#this.is.comment"}, }, wantResp: map[string]any{ - goodUps: "OK", + ups: "OK", }, - name: "fallback_success", - }, { - body: map[string]any{ - "fallback_dns": []string{badUps}, - }, - wantResp: map[string]any{ - badUps: `couldn't communicate with upstream: exchanging with ` + - badUps + ` over tcp: dns: id mismatch`, - }, - name: "fallback_broken", - }, { - body: map[string]any{ - "fallback_dns": []string{goodUps, "#this.is.comment"}, - }, - wantResp: map[string]any{ - goodUps: "OK", - }, - name: "fallback_comment_mix", - }, { - body: map[string]any{ - "upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps}, - }, - wantResp: map[string]any{ - goodUps: "OK", - badUps: `WARNING: couldn't communicate ` + - `with upstream: exchanging with ` + badUps + ` over tcp: ` + - `dns: id mismatch`, - }, - name: "multiple_domain_specific_upstreams", - }, { - body: map[string]any{ - "upstream_dns": []string{"[/domain.example/]/]1.2.3.4"}, - }, - wantResp: map[string]any{ - "[/domain.example/]/]1.2.3.4": `wrong upstream format: ` + - `bad upstream for domain "[/domain.example/]/]1.2.3.4": ` + - `duplicated separator`, - }, - name: "bad_specification", - }, { - body: map[string]any{ - "upstream_dns": []string{"[/domain.example/]" + goodAndBadUps}, - "fallback_dns": []string{"[/domain.example/]" + goodAndBadUps}, - "private_upstream": []string{"[/domain.example/]" + goodAndBadUps}, - }, - wantResp: map[string]any{ - goodUps: "OK", - badUps: `WARNING: couldn't communicate ` + - `with upstream: exchanging with ` + badUps + ` over tcp: ` + - `dns: id mismatch`, - }, - name: "all_different", + name: "comment_mix", }} for _, tc := range testCases { diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index e71a9672..3f877ac7 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -2,14 +2,43 @@ package dnsforward import ( "fmt" + "net" + "net/netip" "os" + "strings" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +const ( + // errNotDomainSpecific is returned when the upstream should be + // domain-specific, but isn't. + errNotDomainSpecific errors.Error = "not a domain-specific upstream" + + // errMissingSeparator is returned when the domain-specific part of the + // upstream configuration line isn't closed. + errMissingSeparator errors.Error = "missing separator" + + // errDupSeparator is returned when the domain-specific part of the upstream + // configuration line contains more than one ending separator. + errDupSeparator errors.Error = "duplicated separator" + + // errNoDefaultUpstreams is returned when there are no default upstreams + // specified in the upstream configuration. + errNoDefaultUpstreams errors.Error = "no default upstreams specified" + + // errWrongResponse is returned when the checked upstream replies in an + // unexpected way. + errWrongResponse errors.Error = "wrong response" ) // loadUpstreams parses upstream DNS servers from the configured file or from @@ -158,3 +187,183 @@ func (s *Server) createBootstrap( return r, boots, nil } + +// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. +// This function is useful for filtering out non-upstream lines from upstream +// configs. +func IsCommentOrEmpty(s string) (ok bool) { + return len(s) == 0 || s[0] == '#' +} + +// newUpstreamConfig validates upstreams and returns an appropriate upstream +// configuration or nil if it can't be built. +// +// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams +// slice already so that this function may be considered useless. +func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) { + // No need to validate comments and empty lines. + upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty) + if len(upstreams) == 0 { + // Consider this case valid since it means the default server should be + // used. + return nil, nil + } + + err = validateUpstreamConfig(upstreams) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return nil, err + } + + conf, err = proxy.ParseUpstreamsConfig( + upstreams, + &upstream.Options{ + Bootstrap: net.DefaultResolver, + Timeout: DefaultTimeout, + }, + ) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return nil, err + } else if len(conf.Upstreams) == 0 { + return nil, errNoDefaultUpstreams + } + + return conf, nil +} + +// validateUpstreamConfig validates each upstream from the upstream +// configuration and returns an error if any upstream is invalid. +// +// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow. +func validateUpstreamConfig(conf []string) (err error) { + for _, u := range conf { + var ups []string + var isSpecific bool + ups, isSpecific, err = splitUpstreamLine(u) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + for _, addr := range ups { + _, err = validateUpstream(addr, isSpecific) + if err != nil { + return fmt.Errorf("validating upstream %q: %w", addr, err) + } + } + } + + return nil +} + +// ValidateUpstreams validates each upstream and returns an error if any +// upstream is invalid or if there are no default upstreams specified. +// +// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow. +func ValidateUpstreams(upstreams []string) (err error) { + _, err = newUpstreamConfig(upstreams) + + return err +} + +// ValidateUpstreamsPrivate validates each upstream and returns an error if any +// upstream is invalid or if there are no default upstreams specified. It also +// checks each domain of domain-specific upstreams for being ARPA pointing to +// a locally-served network. privateNets must not be nil. +func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { + conf, err := newUpstreamConfig(upstreams) + if err != nil { + return fmt.Errorf("creating config: %w", err) + } + + if conf == nil { + return nil + } + + keys := maps.Keys(conf.DomainReservedUpstreams) + slices.Sort(keys) + + var errs []error + for _, domain := range keys { + var subnet netip.Prefix + subnet, err = extractARPASubnet(domain) + if err != nil { + errs = append(errs, err) + + continue + } + + if !privateNets.Contains(subnet.Addr().AsSlice()) { + errs = append( + errs, + fmt.Errorf("arpa domain %q should point to a locally-served network", domain), + ) + } + } + + return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w") +} + +// protocols are the supported URL schemes for upstreams. +var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"} + +// validateUpstream returns an error if u alongside with domains is not a valid +// upstream configuration. useDefault is true if the upstream is +// domain-specific and is configured to point at the default upstream server +// which is validated separately. The upstream is considered domain-specific +// only if domains is at least not nil. +func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) { + // The special server address '#' means that default server must be used. + if useDefault = u == "#" && isSpecific; useDefault { + return useDefault, nil + } + + // Check if the upstream has a valid protocol prefix. + // + // TODO(e.burkov): Validate the domain name. + if proto, _, ok := strings.Cut(u, "://"); ok { + if !slices.Contains(protocols, proto) { + return false, fmt.Errorf("bad protocol %q", proto) + } + } else if _, err = netip.ParseAddr(u); err == nil { + return false, nil + } else if _, err = netip.ParseAddrPort(u); err == nil { + return false, nil + } + + return false, err +} + +// splitUpstreamLine returns the upstreams and the specified domains. domains +// is nil when the upstream is not domains-specific. Otherwise it may also be +// empty. +func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) { + if !strings.HasPrefix(upstreamStr, "[/") { + return []string{upstreamStr}, false, nil + } + + defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }() + + doms, ups, found := strings.Cut(upstreamStr[2:], "/]") + if !found { + return nil, false, errMissingSeparator + } else if strings.Contains(ups, "/]") { + return nil, false, errDupSeparator + } + + for i, host := range strings.Split(doms, "/") { + if host == "" { + continue + } + + err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*.")) + if err != nil { + return nil, false, fmt.Errorf("domain at index %d: %w", i, err) + } + + isSpecific = true + } + + return strings.Fields(ups), isSpecific, nil +} diff --git a/internal/dnsforward/upstreams_internal_test.go b/internal/dnsforward/upstreams_internal_test.go new file mode 100644 index 00000000..d3d4ebb6 --- /dev/null +++ b/internal/dnsforward/upstreams_internal_test.go @@ -0,0 +1,209 @@ +package dnsforward + +import ( + "net" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUpstreamConfigValidator(t *testing.T) { + goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + err := w.WriteMsg(new(dns.Msg).SetReply(m)) + require.NoError(testutil.PanicT{}, err) + }) + badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) { + err := w.WriteMsg(new(dns.Msg)) + require.NoError(testutil.PanicT{}, err) + }) + + goodUps := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, goodHandler).String(), + }).String() + badUps := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, badHandler).String(), + }).String() + + goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ") + + // upsTimeout restricts the checking process to prevent the test from + // hanging. + const upsTimeout = 100 * time.Millisecond + + testCases := []struct { + want map[string]string + name string + general []string + fallback []string + private []string + }{{ + name: "success", + general: []string{goodUps}, + want: map[string]string{ + goodUps: "OK", + }, + }, { + name: "broken", + general: []string{badUps}, + want: map[string]string{ + badUps: `couldn't communicate with upstream: exchanging with ` + + badUps + ` over tcp: dns: id mismatch`, + }, + }, { + name: "both", + general: []string{goodUps, badUps, goodUps}, + want: map[string]string{ + goodUps: "OK", + badUps: `couldn't communicate with upstream: exchanging with ` + + badUps + ` over tcp: dns: id mismatch`, + }, + }, { + name: "domain_specific_error", + general: []string{"[/domain.example/]" + badUps}, + want: map[string]string{ + badUps: `WARNING: couldn't communicate ` + + `with upstream: exchanging with ` + badUps + ` over tcp: ` + + `dns: id mismatch`, + }, + }, { + name: "fallback_success", + fallback: []string{goodUps}, + want: map[string]string{ + goodUps: "OK", + }, + }, { + name: "fallback_broken", + fallback: []string{badUps}, + want: map[string]string{ + badUps: `couldn't communicate with upstream: exchanging with ` + + badUps + ` over tcp: dns: id mismatch`, + }, + }, { + name: "multiple_domain_specific_upstreams", + general: []string{"[/domain.example/]" + goodAndBadUps}, + want: map[string]string{ + goodUps: "OK", + badUps: `WARNING: couldn't communicate ` + + `with upstream: exchanging with ` + badUps + ` over tcp: ` + + `dns: id mismatch`, + }, + }, { + name: "bad_specification", + general: []string{"[/domain.example/]/]1.2.3.4"}, + want: map[string]string{ + "[/domain.example/]/]1.2.3.4": `splitting upstream line ` + + `"[/domain.example/]/]1.2.3.4": duplicated separator`, + }, + }, { + name: "all_different", + general: []string{"[/domain.example/]" + goodAndBadUps}, + fallback: []string{"[/domain.example/]" + goodAndBadUps}, + private: []string{"[/domain.example/]" + goodAndBadUps}, + want: map[string]string{ + goodUps: "OK", + badUps: `WARNING: couldn't communicate ` + + `with upstream: exchanging with ` + badUps + ` over tcp: ` + + `dns: id mismatch`, + }, + }, { + name: "bad_specific_domains", + general: []string{"[/example/]/]" + goodUps}, + fallback: []string{"[/example/" + goodUps}, + private: []string{"[/example//bad.123/]" + goodUps}, + want: map[string]string{ + `[/example/]/]` + goodUps: `splitting upstream line ` + + `"[/example/]/]` + goodUps + `": duplicated separator`, + `[/example/` + goodUps: `splitting upstream line ` + + `"[/example/` + goodUps + `": missing separator`, + `[/example//bad.123/]` + goodUps: `splitting upstream line ` + + `"[/example//bad.123/]` + goodUps + `": domain at index 2: ` + + `bad domain name "bad.123": ` + + `bad top-level domain name label "123": all octets are numeric`, + }, + }, { + name: "non-specific_default", + general: []string{ + "#", + "[/example/]#", + }, + want: map[string]string{ + "#": "not a domain-specific upstream", + }, + }, { + name: "bad_proto", + general: []string{ + "bad://1.2.3.4", + }, + want: map[string]string{ + "bad://1.2.3.4": `bad protocol "bad"`, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cv := newUpstreamConfigValidator(tc.general, tc.fallback, tc.private, &upstream.Options{ + Timeout: upsTimeout, + Bootstrap: net.DefaultResolver, + }) + cv.check() + cv.close() + + assert.Equal(t, tc.want, cv.status()) + }) + } +} + +func TestUpstreamConfigValidator_Check_once(t *testing.T) { + reqs := atomic.Int32{} + reset := func() { reqs.Store(0) } + + hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + err := w.WriteMsg(new(dns.Msg).SetReply(m)) + require.NoError(testutil.PanicT{}, err) + reqs.Add(1) + }) + + addr := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, hdlr).String(), + }).String() + twoAddrs := strings.Join([]string{addr, addr}, " ") + + testCases := []struct { + name string + ups []string + }{{ + name: "common", + ups: []string{addr, addr, addr}, + }, { + name: "domain-specific", + ups: []string{"[/one.example/]" + addr, "[/two.example/]" + twoAddrs}, + }, { + name: "both", + ups: []string{addr, "[/one.example/]" + addr, addr, "[/two.example/]" + twoAddrs}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(reset) + + cv := newUpstreamConfigValidator(tc.ups, nil, nil, &upstream.Options{ + Timeout: 100 * time.Millisecond, + }) + cv.check() + cv.close() + + assert.Equal(t, int32(1), reqs.Load()) + }) + } +}