package dnsforward import ( "encoding/json" "fmt" "io" "net/http" "net/netip" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "golang.org/x/exp/slices" ) // jsonDNSConfig is the JSON representation of the DNS server configuration. // // TODO(s.chzhen): Split it into smaller pieces. Use aghalg.NullBool instead // of *bool. type jsonDNSConfig struct { // Upstreams is the list of upstream DNS servers. Upstreams *[]string `json:"upstream_dns"` // UpstreamsFile is the file containing upstream DNS servers. UpstreamsFile *string `json:"upstream_dns_file"` // Bootstraps is the list of DNS servers resolving IP addresses of the // upstream DoH/DoT resolvers. Bootstraps *[]string `json:"bootstrap_dns"` // Fallbacks is the list of fallback DNS servers used when upstream DNS // servers are not responding. Fallbacks *[]string `json:"fallback_dns"` // ProtectionEnabled defines if protection is enabled. ProtectionEnabled *bool `json:"protection_enabled"` // Ratelimit is the number of requests per second allowed per client. Ratelimit *uint32 `json:"ratelimit"` // RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for // rate limiting requests. RatelimitSubnetLenIPv4 *int `json:"ratelimit_subnet_len_ipv4"` // RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for // rate limiting requests. RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"` // RatelimitWhitelist is a list of IP addresses excluded from rate limiting. RatelimitWhitelist *[]string `json:"ratelimit_whitelist"` // BlockingMode defines the way blocked responses are constructed. BlockingMode *filtering.BlockingMode `json:"blocking_mode"` // EDNSCSEnabled defines if EDNS Client Subnet is enabled. 
EDNSCSEnabled *bool `json:"edns_cs_enabled"` // EDNSCSUseCustom defines if EDNSCSCustomIP should be used. EDNSCSUseCustom *bool `json:"edns_cs_use_custom"` // DNSSECEnabled defines if DNSSEC is enabled. DNSSECEnabled *bool `json:"dnssec_enabled"` // DisableIPv6 defines if IPv6 addresses should be dropped. DisableIPv6 *bool `json:"disable_ipv6"` // UpstreamMode defines the way DNS requests are constructed. UpstreamMode *string `json:"upstream_mode"` // BlockedResponseTTL is the TTL for blocked responses. BlockedResponseTTL *uint32 `json:"blocked_response_ttl"` // CacheSize in bytes. CacheSize *uint32 `json:"cache_size"` // CacheMinTTL is custom minimum TTL for cached DNS responses. CacheMinTTL *uint32 `json:"cache_ttl_min"` // CacheMaxTTL is custom maximum TTL for cached DNS responses. CacheMaxTTL *uint32 `json:"cache_ttl_max"` // CacheOptimistic defines if expired entries should be served. CacheOptimistic *bool `json:"cache_optimistic"` // ResolveClients defines if clients IPs should be resolved into hostnames. ResolveClients *bool `json:"resolve_clients"` // UsePrivateRDNS defines if privates DNS resolvers should be used. UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"` // LocalPTRUpstreams is the list of local private DNS resolvers. LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"` // BlockingIPv4 is custom IPv4 address for blocked A requests. BlockingIPv4 netip.Addr `json:"blocking_ipv4"` // BlockingIPv6 is custom IPv6 address for blocked AAAA requests. BlockingIPv6 netip.Addr `json:"blocking_ipv6"` // DisabledUntil is a timestamp until when the protection is disabled. DisabledUntil *time.Time `json:"protection_disabled_until"` // EDNSCSCustomIP is custom IP for EDNS Client Subnet. EDNSCSCustomIP netip.Addr `json:"edns_cs_custom_ip"` // DefaultLocalPTRUpstreams is used to pass the addresses from // systemResolvers to the front-end. It's not a pointer to the slice since // there is no need to omit it while decoding from JSON. 
DefaultLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"` } func (s *Server) getDNSConfig() (c *jsonDNSConfig) { protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus() s.serverLock.RLock() defer s.serverLock.RUnlock() upstreams := stringutil.CloneSliceOrEmpty(s.conf.UpstreamDNS) upstreamFile := s.conf.UpstreamDNSFileName bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS) fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS) blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode() blockedResponseTTL := s.dnsFilter.BlockedResponseTTL() ratelimit := s.conf.Ratelimit ratelimitSubnetLenIPv4 := s.conf.RatelimitSubnetLenIPv4 ratelimitSubnetLenIPv6 := s.conf.RatelimitSubnetLenIPv6 ratelimitWhitelist := stringutil.CloneSliceOrEmpty(s.conf.RatelimitWhitelist) customIP := s.conf.EDNSClientSubnet.CustomIP enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled useCustom := s.conf.EDNSClientSubnet.UseCustom enableDNSSEC := s.conf.EnableDNSSEC aaaaDisabled := s.conf.AAAADisabled cacheSize := s.conf.CacheSize cacheMinTTL := s.conf.CacheMinTTL cacheMaxTTL := s.conf.CacheMaxTTL cacheOptimistic := s.conf.CacheOptimistic resolveClients := s.conf.AddrProcConf.UseRDNS usePrivateRDNS := s.conf.UsePrivateRDNS localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers) var upstreamMode string if s.conf.FastestAddr { upstreamMode = "fastest_addr" } else if s.conf.AllServers { upstreamMode = "parallel" } defPTRUps, err := s.defaultLocalPTRUpstreams() if err != nil { log.Error("dnsforward: %s", err) } return &jsonDNSConfig{ Upstreams: &upstreams, UpstreamsFile: &upstreamFile, Bootstraps: &bootstraps, Fallbacks: &fallbacks, ProtectionEnabled: &protectionEnabled, BlockingMode: &blockingMode, BlockingIPv4: blockingIPv4, BlockingIPv6: blockingIPv6, Ratelimit: &ratelimit, RatelimitSubnetLenIPv4: &ratelimitSubnetLenIPv4, RatelimitSubnetLenIPv6: &ratelimitSubnetLenIPv6, RatelimitWhitelist: 
&ratelimitWhitelist, EDNSCSCustomIP: customIP, EDNSCSEnabled: &enableEDNSClientSubnet, EDNSCSUseCustom: &useCustom, DNSSECEnabled: &enableDNSSEC, DisableIPv6: &aaaaDisabled, BlockedResponseTTL: &blockedResponseTTL, CacheSize: &cacheSize, CacheMinTTL: &cacheMinTTL, CacheMaxTTL: &cacheMaxTTL, CacheOptimistic: &cacheOptimistic, UpstreamMode: &upstreamMode, ResolveClients: &resolveClients, UsePrivateRDNS: &usePrivateRDNS, LocalPTRUpstreams: &localPTRUpstreams, DefaultLocalPTRUpstreams: defPTRUps, DisabledUntil: protectionDisabledUntil, } } // defaultLocalPTRUpstreams returns the list of default local PTR resolvers // filtered of AdGuard Home's own DNS server addresses. It may appear empty. func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) { matcher, err := s.conf.ourAddrsSet() if err != nil { // Don't wrap the error because it's informative enough as is. return nil, err } sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher.Has) ups = make([]string, 0, len(sysResolvers)) for _, r := range sysResolvers { ups = append(ups, r.String()) } return ups, nil } // handleGetConfig handles requests to the GET /control/dns_info endpoint. func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { resp := s.getDNSConfig() aghhttp.WriteJSONResponseOK(w, r, resp) } // checkBlockingMode returns an error if blocking mode is invalid. func (req *jsonDNSConfig) checkBlockingMode() (err error) { if req.BlockingMode == nil { return nil } return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6) } // checkUpstreamsMode returns an error if the upstream mode is invalid. func (req *jsonDNSConfig) checkUpstreamsMode() (err error) { if req.UpstreamMode == nil { return nil } mode := *req.UpstreamMode if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok { return fmt.Errorf("upstream_mode: incorrect value %q", mode) } return nil } // checkBootstrap returns an error if any bootstrap address is invalid. 
func (req *jsonDNSConfig) checkBootstrap() (err error) {
	if req.Bootstraps == nil {
		return nil
	}

	var b string
	defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()

	for _, b = range *req.Bootstraps {
		if b == "" {
			return errors.Error("empty")
		}

		var resolver *upstream.UpstreamResolver
		resolver, err = upstream.NewUpstreamResolver(b, nil)
		if err != nil {
			// Don't wrap the error because it's informative enough as is.
			return err
		}

		// The resolver is only created to check that b can be parsed, so
		// close it right away instead of leaking it.  A close error doesn't
		// make the address invalid, so it's only logged, the same way
		// closeBoots does it.
		logCloserErr(resolver, "dnsforward: closing bootstrap %s: %s", b)
	}

	return nil
}

// checkFallbacks returns an error if any fallback address is invalid.  A
// missing list of fallbacks is valid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
	if req.Fallbacks == nil {
		return nil
	}

	err = ValidateUpstreams(*req.Fallbacks)
	if err != nil {
		return fmt.Errorf("fallback servers: %w", err)
	}

	return nil
}

// validate returns an error if any field of req is invalid.  The checks run in
// a fixed order, and the first failing check's error is returned.
//
// TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
	defer func() { err = errors.Annotate(err, "validating dns config: %w") }()

	err = req.validateUpstreamDNSServers(privateNets)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = req.checkRatelimitSubnetMaskLen()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = req.checkRatelimitWhitelist()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = req.checkBlockingMode()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = req.checkUpstreamsMode()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = req.checkCacheTTL()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	return nil
}

// validateUpstreamDNSServers returns an error if any of the upstream-related
// fields of req, i.e. the upstreams, the local PTR upstreams, the bootstraps,
// or the fallbacks, is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { if req.Upstreams != nil { err = ValidateUpstreams(*req.Upstreams) if err != nil { return fmt.Errorf("upstream servers: %w", err) } } if req.LocalPTRUpstreams != nil { err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) if err != nil { return fmt.Errorf("private upstream servers: %w", err) } } err = req.checkBootstrap() if err != nil { // Don't wrap the error since it's informative enough as is. return err } err = req.checkFallbacks() if err != nil { // Don't wrap the error since it's informative enough as is. return err } return nil } // checkCacheTTL returns an error if the configuration of the cache TTL is // invalid. func (req *jsonDNSConfig) checkCacheTTL() (err error) { if req.CacheMinTTL == nil && req.CacheMaxTTL == nil { return nil } var minTTL, maxTTL uint32 if req.CacheMinTTL != nil { minTTL = *req.CacheMinTTL } if req.CacheMaxTTL != nil { maxTTL = *req.CacheMaxTTL } if minTTL <= maxTTL { return nil } return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max") } // checkRatelimitSubnetMaskLen returns an error if the length of the subnet mask // for IPv4 or IPv6 addresses is invalid. func (req *jsonDNSConfig) checkRatelimitSubnetMaskLen() (err error) { err = checkInclusion(req.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen) if err != nil { return fmt.Errorf("ratelimit_subnet_len_ipv4 is invalid: %w", err) } err = checkInclusion(req.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen) if err != nil { return fmt.Errorf("ratelimit_subnet_len_ipv6 is invalid: %w", err) } return nil } // checkInclusion returns an error if a ptr is not nil and points to value, // that not in the inclusive range between minN and maxN. 
func checkInclusion(ptr *int, minN, maxN int) (err error) {
	if ptr == nil {
		return nil
	}

	n := *ptr
	if n < minN {
		return fmt.Errorf("value %d less than min %d", n, minN)
	}

	if n > maxN {
		return fmt.Errorf("value %d greater than max %d", n, maxN)
	}

	return nil
}

// checkRatelimitWhitelist returns an error if any item of the rate limit
// whitelist isn't a valid IP address.  A missing whitelist is valid.
func (req *jsonDNSConfig) checkRatelimitWhitelist() (err error) {
	whitelist := req.RatelimitWhitelist
	if whitelist == nil {
		return nil
	}

	for idx, addr := range *whitelist {
		_, err = netip.ParseAddr(addr)
		if err != nil {
			return fmt.Errorf("ratelimit whitelist: at index %d: %w", idx, err)
		}
	}

	return nil
}

// handleSetConfig handles requests to the POST /control/dns_config endpoint.
// It validates the request, applies it, marks the configuration as modified,
// and reconfigures the server if any of the applied settings requires that.
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
	req := &jsonDNSConfig{}
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)

		return
	}

	if err := req.validate(s.privateNets); err != nil {
		aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

		return
	}

	restart := s.setConfig(req)
	s.conf.ConfigModified()

	if !restart {
		return
	}

	if err := s.Reconfigure(nil); err != nil {
		aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
	}
}

// setConfig applies the values provided in dc to the server configuration
// under s.serverLock.  shouldRestart is true if the server should be restarted
// to apply the changes.
func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { s.serverLock.Lock() defer s.serverLock.Unlock() if dc.BlockingMode != nil { s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6) } if dc.BlockedResponseTTL != nil { s.dnsFilter.SetBlockedResponseTTL(*dc.BlockedResponseTTL) } if dc.ProtectionEnabled != nil { s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled) } if dc.UpstreamMode != nil { s.conf.AllServers = *dc.UpstreamMode == "parallel" s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr" } if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom { s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP } setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled) setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6) return s.setConfigRestartable(dc) } // setIfNotNil sets the value pointed at by currentPtr to the value pointed at // by newPtr if newPtr is not nil. currentPtr must not be nil. func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) { if newPtr == nil { return false } *currentPtr = *newPtr return true } // setConfigRestartable sets the parameters which trigger a restart. // shouldRestart is true if the server should be restarted to apply changes. // s.serverLock is expected to be locked. // // TODO(a.garipov): Some of these could probably be updated without a restart. // Inspect and consider refactoring. 
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) { for _, hasSet := range []bool{ setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams), setIfNotNil(&s.conf.LocalPTRResolvers, dc.LocalPTRUpstreams), setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile), setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps), setIfNotNil(&s.conf.FallbackDNS, dc.Fallbacks), setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled), setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom), setIfNotNil(&s.conf.CacheSize, dc.CacheSize), setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL), setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL), setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic), setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients), setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS), setIfNotNil(&s.conf.RatelimitSubnetLenIPv4, dc.RatelimitSubnetLenIPv4), setIfNotNil(&s.conf.RatelimitSubnetLenIPv6, dc.RatelimitSubnetLenIPv6), setIfNotNil(&s.conf.RatelimitWhitelist, dc.RatelimitWhitelist), } { shouldRestart = shouldRestart || hasSet if shouldRestart { break } } if dc.Ratelimit != nil && s.conf.Ratelimit != *dc.Ratelimit { s.conf.Ratelimit = *dc.Ratelimit shouldRestart = true } return shouldRestart } // upstreamJSON is a request body for handleTestUpstreamDNS endpoint. type upstreamJSON struct { Upstreams []string `json:"upstream_dns"` BootstrapDNS []string `json:"bootstrap_dns"` FallbackDNS []string `json:"fallback_dns"` PrivateUpstreams []string `json:"private_upstream"` } // closeBoots closes all the provided bootstrap servers and logs errors if any. func closeBoots(boots []*upstream.UpstreamResolver) { for _, c := range boots { logCloserErr(c, "dnsforward: closing bootstrap %s: %s", c.Address()) } } // handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns // endpoint. 
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { req := &upstreamJSON{} err := json.NewDecoder(r.Body).Decode(req) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) return } req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty) req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty) req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty) req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty) opts := &upstream.Options{ Timeout: s.conf.UpstreamTimeout, PreferIPv6: s.conf.BootstrapPreferIPv6, } var boots []*upstream.UpstreamResolver opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) return } defer closeBoots(boots) cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts) cv.check() cv.close() aghhttp.WriteJSONResponseOK(w, r, cv.status()) } // handleCacheClear is the handler for the POST /control/cache_clear HTTP API. func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) { s.dnsProxy.ClearCache() _, _ = io.WriteString(w, "OK") } // protectionJSON is an object for /control/protection endpoint. type protectionJSON struct { Enabled bool `json:"enabled"` Duration uint `json:"duration"` } // handleSetProtection is a handler for the POST /control/protection HTTP API. 
func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
	protectionReq := &protectionJSON{}
	err := json.NewDecoder(r.Body).Decode(protectionReq)
	if err != nil {
		aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)

		return
	}

	// maxDurationMs is the largest number of milliseconds that can be turned
	// into a time.Duration without overflowing int64.  It's computed in
	// uint64, so it also compiles on platforms where uint is 32 bits wide.
	const maxDurationMs = uint64((1<<63 - 1) / time.Millisecond)

	var disabledUntil *time.Time
	if protectionReq.Duration > 0 {
		if protectionReq.Enabled {
			aghhttp.Error(
				r,
				w,
				http.StatusBadRequest,
				"Setting a duration is only allowed with protection disabling",
			)

			return
		}

		// Duration comes from the client, so reject values for which the
		// multiplication below would overflow and wrap around into a bogus,
		// possibly negative, duration.
		if uint64(protectionReq.Duration) > maxDurationMs {
			aghhttp.Error(
				r,
				w,
				http.StatusBadRequest,
				"duration: value %d greater than max %d",
				protectionReq.Duration,
				maxDurationMs,
			)

			return
		}

		calcTime := time.Now().Add(time.Duration(protectionReq.Duration) * time.Millisecond)
		disabledUntil = &calcTime
	}

	func() {
		s.serverLock.Lock()
		defer s.serverLock.Unlock()

		s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil)
	}()

	s.conf.ConfigModified()

	aghhttp.OK(w)
}

// handleDoH is the DNS-over-HTTPs handler.  It responds with 404 to
// unencrypted requests unless those are allowed by the configuration, and with
// 500 when the DNS server isn't running.
//
// Control flow:
//
//	HTTP server
//	-> dnsforward.handleDoH
//	-> dnsforward.ServeHTTP
//	-> proxy.ServeHTTP
//	-> proxy.handleDNSRequest
//	-> dnsforward.handleDNSRequest
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
	if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil {
		aghhttp.Error(r, w, http.StatusNotFound, "Not Found")

		return
	}

	if !s.IsRunning() {
		aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running")

		return
	}

	s.ServeHTTP(w, r)
}

// registerHandlers registers the DNS-related HTTP API handlers and the DoH
// handlers using s.conf.HTTPRegister.  It does nothing if the handlers have
// already been registered or if s.conf.HTTPRegister is nil.
func (s *Server) registerHandlers() {
	if webRegistered || s.conf.HTTPRegister == nil {
		return
	}

	s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
	s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
	s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
	s.conf.HTTPRegister(http.MethodPost, "/control/protection", s.handleSetProtection)

	s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
	s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)

	s.conf.HTTPRegister(http.MethodPost, "/control/cache_clear", s.handleCacheClear)

	// Register both versions, with and without the trailing slash, to
	// prevent a 301 Moved Permanently redirect when clients request the
	// path without the trailing slash.  Those redirects break some clients.
	//
	// See go doc net/http.ServeMux.
	//
	// See also https://github.com/AdguardTeam/AdGuardHome/issues/2628.
	s.conf.HTTPRegister("", "/dns-query", s.handleDoH)
	s.conf.HTTPRegister("", "/dns-query/", s.handleDoH)

	webRegistered = true
}