mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 21:15:35 +03:00
* dnsfilter: use a single global context object
This commit is contained in:
parent
f1e6a30931
commit
2307f55715
1 changed files with 30 additions and 23 deletions
|
@ -128,14 +128,15 @@ const (
|
|||
FilteredSafeSearch
|
||||
)
|
||||
|
||||
// these variables need to survive coredns reload
|
||||
var (
|
||||
type dnsfContext struct {
|
||||
stats Stats
|
||||
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
|
||||
safebrowsingCache gcache.Cache
|
||||
parentalCache gcache.Cache
|
||||
safeSearchCache gcache.Cache
|
||||
)
|
||||
}
|
||||
|
||||
var gctx dnsfContext // global dnsfilter context
|
||||
|
||||
// Result holds state of hostname check
|
||||
type Result struct {
|
||||
|
@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
|||
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
|
||||
}
|
||||
|
||||
if safeSearchCache == nil {
|
||||
safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound, err := getCachedReason(safeSearchCache, host)
|
||||
cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
|
||||
if isFound {
|
||||
atomic.AddUint64(&stats.Safesearch.CacheHits, 1)
|
||||
atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
|
||||
log.Tracef("%s: found in SafeSearch cache", host)
|
||||
return cachedValue, nil
|
||||
}
|
||||
|
@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
|||
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
|
||||
if ip := net.ParseIP(safeHost); ip != nil {
|
||||
res.IP = ip
|
||||
err = safeSearchCache.Set(host, res)
|
||||
err = gctx.safeSearchCache.Set(host, res)
|
||||
if err != nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
|||
}
|
||||
|
||||
// Cache result
|
||||
err = safeSearchCache.Set(host, res)
|
||||
err = gctx.safeSearchCache.Set(host, res)
|
||||
if err != nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
|||
}
|
||||
return result, nil
|
||||
}
|
||||
if safebrowsingCache == nil {
|
||||
safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
|
||||
result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody)
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
|
|||
}
|
||||
return result, nil
|
||||
}
|
||||
if parentalCache == nil {
|
||||
parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
|
||||
result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody)
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
|
|||
|
||||
// Search for an IP address by host name
|
||||
func searchInDialCache(host string) string {
|
||||
rawValue, err := dialCache.Get(host)
|
||||
rawValue, err := gctx.dialCache.Get(host)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
@ -632,7 +623,7 @@ func searchInDialCache(host string) string {
|
|||
|
||||
// Add "hostname" -> "IP address" entry to cache
|
||||
func addToDialCache(host, ip string) {
|
||||
err := dialCache.Set(host, ip)
|
||||
err := gctx.dialCache.Set(host, ip)
|
||||
if err != nil {
|
||||
log.Debug("dialCache.Set: %s", err)
|
||||
}
|
||||
|
@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp
|
|||
|
||||
// New creates properly initialized DNS Filter that is ready to be used
|
||||
func New(c *Config, filters map[int]string) *Dnsfilter {
|
||||
|
||||
if c != nil {
|
||||
// initialize objects only once
|
||||
if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil {
|
||||
gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
if c.SafeSearchEnabled && gctx.safeSearchCache == nil {
|
||||
gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
if c.ParentalEnabled && gctx.parentalCache == nil {
|
||||
gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
|
||||
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
}
|
||||
}
|
||||
|
||||
d := new(Dnsfilter)
|
||||
|
||||
// Customize the Transport to have larger connection pool,
|
||||
|
@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
|
|||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
if c != nil && len(c.ResolverAddress) != 0 {
|
||||
dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
||||
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
|
||||
}
|
||||
d.client = http.Client{
|
||||
|
@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
|
|||
|
||||
// GetStats return dns filtering stats since startup
|
||||
func (d *Dnsfilter) GetStats() Stats {
|
||||
return stats
|
||||
return gctx.stats
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue