mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-25 14:35:48 +03:00
*(dnsforward): cache upstream instances
✅ Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1296
This commit is contained in:
parent
3dd91cf179
commit
cdd55139fa
3 changed files with 62 additions and 18 deletions
|
@ -113,7 +113,7 @@ type FilteringConfig struct {
|
|||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
|
||||
// This callback function returns the list of upstream servers for a client specified by IP address
|
||||
GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"`
|
||||
GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
|
||||
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
|
||||
|
@ -465,13 +465,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||
clientIP := ipFromAddr(d.Addr)
|
||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||
for _, us := range upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
d.Upstreams = append(d.Upstreams, u)
|
||||
if len(upstreams) > 0 {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
d.Upstreams = upstreams
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
@ -62,8 +63,14 @@ type clientsContainer struct {
|
|||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // IP -> client
|
||||
ipHost map[string]*ClientHost // IP -> Hostname
|
||||
lock sync.Mutex
|
||||
|
||||
// cache for Upstream instances that are used in the case
|
||||
// when custom DNS servers are configured for a client
|
||||
upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
|
||||
|
||||
lock sync.Mutex
|
||||
|
||||
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||
dhcpServer *dhcpd.Server
|
||||
|
||||
testing bool // if TRUE, this object is used for internal tests
|
||||
|
@ -78,6 +85,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
|
|||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipHost = make(map[string]*ClientHost)
|
||||
clients.upstreamsCache = make(map[string][]upstream.Upstream)
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
|
@ -191,6 +199,45 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
|||
return clients.findByIP(ip)
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findByIP(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreams, ok := clients.upstreamsCache[c.Name]
|
||||
if ok {
|
||||
return upstreams
|
||||
}
|
||||
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
clients.upstreamsCache[c.Name] = nil
|
||||
} else {
|
||||
clients.upstreamsCache[c.Name] = upstreams
|
||||
}
|
||||
|
||||
return upstreams
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
@ -355,6 +402,9 @@ func (clients *clientsContainer) Del(name string) bool {
|
|||
// update Name index
|
||||
delete(clients.list, name)
|
||||
|
||||
// update upstreams cache
|
||||
delete(clients.upstreamsCache, name)
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
|
@ -418,10 +468,13 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
|||
|
||||
// update Name index
|
||||
if old.Name != c.Name {
|
||||
delete(clients.list, old.Name)
|
||||
clients.list[c.Name] = old
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
delete(clients.upstreamsCache, name)
|
||||
delete(clients.upstreamsCache, old.Name)
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
}
|
||||
|
|
11
home/dns.go
11
home/dns.go
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
@ -178,18 +179,12 @@ func generateServerConfig() dnsforward.ServerConfig {
|
|||
return newconfig
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []string {
|
||||
c, ok := Context.clients.Find(clientAddr)
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr)
|
||||
return c.Upstreams
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
|
||||
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
||||
|
||||
if len(clientAddr) == 0 {
|
||||
|
|
Loading…
Reference in a new issue