package dnsforward import ( "fmt" "net" "net/netip" "os" "strings" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) const ( // errNotDomainSpecific is returned when the upstream should be // domain-specific, but isn't. errNotDomainSpecific errors.Error = "not a domain-specific upstream" // errMissingSeparator is returned when the domain-specific part of the // upstream configuration line isn't closed. errMissingSeparator errors.Error = "missing separator" // errDupSeparator is returned when the domain-specific part of the upstream // configuration line contains more than one ending separator. errDupSeparator errors.Error = "duplicated separator" // errNoDefaultUpstreams is returned when there are no default upstreams // specified in the upstream configuration. errNoDefaultUpstreams errors.Error = "no default upstreams specified" // errWrongResponse is returned when the checked upstream replies in an // unexpected way. errWrongResponse errors.Error = "wrong response" ) // loadUpstreams parses upstream DNS servers from the configured file or from // the configuration itself. func (s *Server) loadUpstreams() (upstreams []string, err error) { if s.conf.UpstreamDNSFileName == "" { return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil } var data []byte data, err = os.ReadFile(s.conf.UpstreamDNSFileName) if err != nil { return nil, fmt.Errorf("reading upstream from file: %w", err) } upstreams = stringutil.SplitTrimmed(string(data), "\n") log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName) return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil } // prepareUpstreamSettings sets upstream DNS server settings. func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { // Load upstreams either from the file, or from the settings var upstreams []string upstreams, err = s.loadUpstreams() if err != nil { return fmt.Errorf("loading upstreams: %w", err) } s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ Bootstrap: boot, Timeout: s.conf.UpstreamTimeout, HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), PreferIPv6: s.conf.BootstrapPreferIPv6, // Use a customized set of RootCAs, because Go's default mechanism of // loading TLS roots does not always work properly on some routers so we're // loading roots manually and pass it here. // // See [aghtls.SystemRootCAs]. // // TODO(a.garipov): Investigate if that's true. RootCAs: s.conf.TLSv12Roots, CipherSuites: s.conf.TLSCiphers, }) if err != nil { return fmt.Errorf("preparing upstream config: %w", err) } return nil } // prepareUpstreamConfig returns the upstream configuration based on upstreams // and configuration of s. func (s *Server) prepareUpstreamConfig( upstreams []string, defaultUpstreams []string, opts *upstream.Options, ) (uc *proxy.UpstreamConfig, err error) { uc, err = proxy.ParseUpstreamsConfig(upstreams, opts) if err != nil { return nil, fmt.Errorf("parsing upstream config: %w", err) } if len(uc.Upstreams) == 0 && defaultUpstreams != nil { log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams) var defaultUpstreamConfig *proxy.UpstreamConfig defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts) if err != nil { return nil, fmt.Errorf("parsing default upstreams: %w", err) } uc.Upstreams = defaultUpstreamConfig.Upstreams } return uc, nil } // UpstreamHTTPVersions returns the HTTP versions for upstream configuration // depending on configuration. func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { if !http3 { return upstream.DefaultHTTPVersions } return []upstream.HTTPVersion{ upstream.HTTPVersion3, upstream.HTTPVersion2, upstream.HTTPVersion11, } } // setProxyUpstreamMode sets the upstream mode and related settings in conf // based on provided parameters. func setProxyUpstreamMode( conf *proxy.Config, allServers bool, fastestAddr bool, fastestTimeout time.Duration, ) { if allServers { conf.UpstreamMode = proxy.UModeParallel } else if fastestAddr { conf.UpstreamMode = proxy.UModeFastestAddr conf.FastestPingTimeout = fastestTimeout } else { conf.UpstreamMode = proxy.UModeLoadBalance } } // createBootstrap returns a bootstrap resolver based on the configuration of s. // boots are the upstream resolvers that should be closed after use. r is the // actual bootstrap resolver, which may include the system hosts. // // TODO(e.burkov): This function currently returns a resolver and a slice of // the upstream resolvers, which are essentially the same. boots are returned // for being able to close them afterwards, but it introduces an implicit // contract that r could only be used before that. Anyway, this code should // improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver] // and be used here. func (s *Server) createBootstrap( addrs []string, opts *upstream.Options, ) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) { if len(addrs) == 0 { addrs = defaultBootstrap } boots, err = aghnet.ParseBootstraps(addrs, opts) if err != nil { // Don't wrap the error, since it's informative enough as is. return nil, nil, err } var parallel upstream.ParallelResolver for _, b := range boots { parallel = append(parallel, b) } if s.etcHosts != nil { r = upstream.ConsequentResolver{s.etcHosts, parallel} } else { r = parallel } return r, boots, nil } // IsCommentOrEmpty returns true if s starts with a "#" character or is empty. // This function is useful for filtering out non-upstream lines from upstream // configs. func IsCommentOrEmpty(s string) (ok bool) { return len(s) == 0 || s[0] == '#' } // newUpstreamConfig validates upstreams and returns an appropriate upstream // configuration or nil if it can't be built. // // TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams // slice already so that this function may be considered useless. func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) { // No need to validate comments and empty lines. upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty) if len(upstreams) == 0 { // Consider this case valid since it means the default server should be // used. return nil, nil } err = validateUpstreamConfig(upstreams) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err } conf, err = proxy.ParseUpstreamsConfig( upstreams, &upstream.Options{ Bootstrap: net.DefaultResolver, Timeout: DefaultTimeout, }, ) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err } else if len(conf.Upstreams) == 0 { return nil, errNoDefaultUpstreams } return conf, nil } // validateUpstreamConfig validates each upstream from the upstream // configuration and returns an error if any upstream is invalid. // // TODO(e.burkov): Merge with [upstreamConfigValidator] somehow. func validateUpstreamConfig(conf []string) (err error) { for _, u := range conf { var ups []string var isSpecific bool ups, isSpecific, err = splitUpstreamLine(u) if err != nil { // Don't wrap the error since it's informative enough as is. return err } for _, addr := range ups { _, err = validateUpstream(addr, isSpecific) if err != nil { return fmt.Errorf("validating upstream %q: %w", addr, err) } } } return nil } // ValidateUpstreams validates each upstream and returns an error if any // upstream is invalid or if there are no default upstreams specified. // // TODO(e.burkov): Merge with [upstreamConfigValidator] somehow. func ValidateUpstreams(upstreams []string) (err error) { _, err = newUpstreamConfig(upstreams) return err } // ValidateUpstreamsPrivate validates each upstream and returns an error if any // upstream is invalid or if there are no default upstreams specified. It also // checks each domain of domain-specific upstreams for being ARPA pointing to // a locally-served network. privateNets must not be nil. func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { conf, err := newUpstreamConfig(upstreams) if err != nil { return fmt.Errorf("creating config: %w", err) } if conf == nil { return nil } keys := maps.Keys(conf.DomainReservedUpstreams) slices.Sort(keys) var errs []error for _, domain := range keys { var subnet netip.Prefix subnet, err = extractARPASubnet(domain) if err != nil { errs = append(errs, err) continue } if !privateNets.Contains(subnet.Addr().AsSlice()) { errs = append( errs, fmt.Errorf("arpa domain %q should point to a locally-served network", domain), ) } } return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w") } // protocols are the supported URL schemes for upstreams. var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"} // validateUpstream returns an error if u alongside with domains is not a valid // upstream configuration. useDefault is true if the upstream is // domain-specific and is configured to point at the default upstream server // which is validated separately. The upstream is considered domain-specific // only if domains is at least not nil. func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) { // The special server address '#' means that default server must be used. if useDefault = u == "#" && isSpecific; useDefault { return useDefault, nil } // Check if the upstream has a valid protocol prefix. // // TODO(e.burkov): Validate the domain name. if proto, _, ok := strings.Cut(u, "://"); ok { if !slices.Contains(protocols, proto) { return false, fmt.Errorf("bad protocol %q", proto) } } else if _, err = netip.ParseAddr(u); err == nil { return false, nil } else if _, err = netip.ParseAddrPort(u); err == nil { return false, nil } return false, err } // splitUpstreamLine returns the upstreams and the specified domains. domains // is nil when the upstream is not domains-specific. Otherwise it may also be // empty. func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) { if !strings.HasPrefix(upstreamStr, "[/") { return []string{upstreamStr}, false, nil } defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }() doms, ups, found := strings.Cut(upstreamStr[2:], "/]") if !found { return nil, false, errMissingSeparator } else if strings.Contains(ups, "/]") { return nil, false, errDupSeparator } for i, host := range strings.Split(doms, "/") { if host == "" { continue } err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*.")) if err != nil { return nil, false, fmt.Errorf("domain at index %d: %w", i, err) } isSpecific = true } return strings.Fields(ups), isSpecific, nil }