*(dnsforward): cache upstream instances

 Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1296
This commit is contained in:
Andrey Meshkov 2019-12-23 19:31:27 +03:00
parent 3dd91cf179
commit cdd55139fa
3 changed files with 62 additions and 18 deletions

View file

@ -113,7 +113,7 @@ type FilteringConfig struct {
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// This callback function returns the list of upstream servers for a client specified by IP address // This callback function returns the list of upstream servers for a client specified by IP address
GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"` GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
@ -465,13 +465,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
clientIP := ipFromAddr(d.Addr) clientIP := ipFromAddr(d.Addr)
upstreams := s.conf.GetUpstreamsByClient(clientIP) upstreams := s.conf.GetUpstreamsByClient(clientIP)
for _, us := range upstreams { if len(upstreams) > 0 {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second}) log.Debug("Using custom upstreams for %s", clientIP)
if err != nil { d.Upstreams = upstreams
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
d.Upstreams = append(d.Upstreams, u)
} }
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils" "github.com/AdguardTeam/golibs/utils"
) )
@ -62,8 +63,14 @@ type clientsContainer struct {
list map[string]*Client // name -> client list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client idIndex map[string]*Client // IP -> client
ipHost map[string]*ClientHost // IP -> Hostname ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
// cache for Upstream instances that are used in the case
// when custom DNS servers are configured for a client
upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
lock sync.Mutex
// dhcpServer is used for looking up clients IP addresses by MAC addresses
dhcpServer *dhcpd.Server dhcpServer *dhcpd.Server
testing bool // if TRUE, this object is used for internal tests testing bool // if TRUE, this object is used for internal tests
@ -78,6 +85,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
clients.list = make(map[string]*Client) clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client) clients.idIndex = make(map[string]*Client)
clients.ipHost = make(map[string]*ClientHost) clients.ipHost = make(map[string]*ClientHost)
clients.upstreamsCache = make(map[string][]upstream.Upstream)
clients.dhcpServer = dhcpServer clients.dhcpServer = dhcpServer
clients.addFromConfig(objects) clients.addFromConfig(objects)
@ -191,6 +199,45 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return clients.findByIP(ip) return clients.findByIP(ip)
} }
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
if !ok {
return nil
}
if len(c.Upstreams) == 0 {
return nil
}
upstreams, ok := clients.upstreamsCache[c.Name]
if ok {
return upstreams
}
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
upstreams = append(upstreams, u)
}
if len(upstreams) == 0 {
clients.upstreamsCache[c.Name] = nil
} else {
clients.upstreamsCache[c.Name] = upstreams
}
return upstreams
}
// Find searches for a client by IP (and does not lock anything) // Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) { func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
ipAddr := net.ParseIP(ip) ipAddr := net.ParseIP(ip)
@ -355,6 +402,9 @@ func (clients *clientsContainer) Del(name string) bool {
// update Name index // update Name index
delete(clients.list, name) delete(clients.list, name)
// update upstreams cache
delete(clients.upstreamsCache, name)
// update ID index // update ID index
for _, id := range c.IDs { for _, id := range c.IDs {
delete(clients.idIndex, id) delete(clients.idIndex, id)
@ -418,10 +468,13 @@ func (clients *clientsContainer) Update(name string, c Client) error {
// update Name index // update Name index
if old.Name != c.Name { if old.Name != c.Name {
delete(clients.list, old.Name)
clients.list[c.Name] = old clients.list[c.Name] = old
} }
// update upstreams cache
delete(clients.upstreamsCache, name)
delete(clients.upstreamsCache, old.Name)
*old = c *old = c
return nil return nil
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
@ -178,18 +179,12 @@ func generateServerConfig() dnsforward.ServerConfig {
return newconfig return newconfig
} }
func getUpstreamsByClient(clientAddr string) []string { func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
c, ok := Context.clients.Find(clientAddr) return Context.clients.FindUpstreams(clientAddr)
if !ok {
return []string{}
}
log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr)
return c.Upstreams
} }
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
ApplyBlockedServices(setts, config.DNS.BlockedServices) ApplyBlockedServices(setts, config.DNS.BlockedServices)
if len(clientAddr) == 0 { if len(clientAddr) == 0 {