diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 0af717be..a2cf5b65 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -113,7 +113,7 @@ type FilteringConfig struct { FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` // This callback function returns the list of upstream servers for a client specified by IP address - GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"` + GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"` ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features @@ -465,13 +465,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { clientIP := ipFromAddr(d.Addr) upstreams := s.conf.GetUpstreamsByClient(clientIP) - for _, us := range upstreams { - u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second}) - if err != nil { - log.Error("upstream.AddressToUpstream: %s: %s", us, err) - continue - } - d.Upstreams = append(d.Upstreams, u) + if len(upstreams) > 0 { + log.Debug("Using custom upstreams for %s", clientIP) + d.Upstreams = upstreams } } diff --git a/home/clients.go b/home/clients.go index 468cdfe0..8769ebc3 100644 --- a/home/clients.go +++ b/home/clients.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" ) @@ -62,8 +63,14 @@ type clientsContainer struct { list map[string]*Client // name -> client idIndex map[string]*Client // IP -> client ipHost map[string]*ClientHost // IP -> Hostname - lock sync.Mutex + // cache for Upstream instances that are used in the case + // when custom DNS servers are configured for a client + upstreamsCache map[string][]upstream.Upstream // name -> []Upstream + + lock sync.Mutex + + // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server testing bool // if TRUE, this object is used for internal tests @@ -78,6 +85,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) clients.ipHost = make(map[string]*ClientHost) + clients.upstreamsCache = make(map[string][]upstream.Upstream) clients.dhcpServer = dhcpServer clients.addFromConfig(objects) @@ -191,6 +199,45 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) { return clients.findByIP(ip) } +// FindUpstreams looks for upstreams configured for the client +// If no client found for this IP, or if no custom upstreams are configured, +// this method returns nil +func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { + clients.lock.Lock() + defer clients.lock.Unlock() + + c, ok := clients.findByIP(ip) + if !ok { + return nil + } + + if len(c.Upstreams) == 0 { + return nil + } + + upstreams, ok := clients.upstreamsCache[c.Name] + if ok { + return upstreams + } + + for _, us := range c.Upstreams { + u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) + if err != nil { + log.Error("upstream.AddressToUpstream: %s: %s", us, err) + continue + } + upstreams = append(upstreams, u) + } + + if len(upstreams) == 0 { + clients.upstreamsCache[c.Name] = nil + } else { + clients.upstreamsCache[c.Name] = upstreams + } + + return upstreams +} + // Find searches for a client by IP (and does not lock anything) func (clients *clientsContainer) findByIP(ip string) (Client, bool) { ipAddr := net.ParseIP(ip) @@ -355,6 +402,9 @@ func (clients *clientsContainer) Del(name string) bool { // update Name index delete(clients.list, name) + // update upstreams cache + delete(clients.upstreamsCache, name) + // update ID index for _, id := range c.IDs { delete(clients.idIndex, id) @@ -418,10 +468,13 @@ func (clients *clientsContainer) Update(name string, c Client) error { // update Name index if old.Name != c.Name { - delete(clients.list, old.Name) clients.list[c.Name] = old } + // update upstreams cache + delete(clients.upstreamsCache, name) + delete(clients.upstreamsCache, old.Name) + *old = c return nil } diff --git a/home/dns.go b/home/dns.go index 1a0666fb..f760ff77 100644 --- a/home/dns.go +++ b/home/dns.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -178,18 +179,12 @@ func generateServerConfig() dnsforward.ServerConfig { return newconfig } -func getUpstreamsByClient(clientAddr string) []string { - c, ok := Context.clients.Find(clientAddr) - if !ok { - return []string{} - } - log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr) - return c.Upstreams +func getUpstreamsByClient(clientAddr string) []upstream.Upstream { + return Context.clients.FindUpstreams(clientAddr) } // If a client has his own settings, apply them func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { - ApplyBlockedServices(setts, config.DNS.BlockedServices) if len(clientAddr) == 0 {