diff --git a/config.go b/config.go index 8a753a49..ec0843b2 100644 --- a/config.go +++ b/config.go @@ -27,7 +27,7 @@ type configuration struct { Filters []filter `yaml:"filters"` UserRules []string `yaml:"user_rules"` - sync.Mutex `yaml:"-"` + sync.RWMutex `yaml:"-"` } type coreDNSConfig struct { diff --git a/control.go b/control.go index d60c6509..cc902f3a 100644 --- a/control.go +++ b/control.go @@ -789,10 +789,11 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { "enabled": config.CoreDNS.FilteringEnabled, } + config.RLock() data["filters"] = config.Filters data["user_rules"] = config.UserRules - json, err := json.Marshal(data) + config.RUnlock() if err != nil { errortext := fmt.Sprintf("Unable to marshal status json: %s", err) @@ -1122,7 +1123,6 @@ func runFilterRefreshers() { func refreshFiltersIfNeccessary() int { now := time.Now() config.Lock() - defer config.Unlock() // deduplicate // TODO: move it somewhere else @@ -1154,6 +1154,7 @@ func refreshFiltersIfNeccessary() int { updateCount++ } } + config.Unlock() if updateCount > 0 { err := writeFilterFile() @@ -1237,6 +1238,7 @@ func writeFilterFile() error { log.Printf("Writing filter file: %s", filterpath) // TODO: check if file contents have modified data := []byte{} + config.RLock() filters := config.Filters for _, filter := range filters { if !filter.Enabled { @@ -1249,6 +1251,7 @@ func writeFilterFile() error { data = append(data, []byte(rule)...) data = append(data, '\n') } + config.RUnlock() err := ioutil.WriteFile(filterpath+".tmp", data, 0644) if err != nil { log.Printf("Couldn't write filter file: %s", err) diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 530e2c2b..dcdb50c8 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -55,6 +55,8 @@ type plug struct { ParentalBlockHost string QueryLogEnabled bool BlockedTTL uint32 // in seconds, default 3600 + + sync.RWMutex } var defaultPlugin = plug{ @@ -246,17 +248,21 @@ func (p *plug) parseEtcHosts(text string) bool { } func (p *plug) onShutdown() error { + p.Lock() p.d.Destroy() p.d = nil + p.Unlock() return nil } func (p *plug) onFinalShutdown() error { + logBufferLock.Lock() err := flushToFile(logBuffer) if err != nil { log.Printf("failed to flush to file: %s", err) return err } + logBufferLock.Unlock() return nil } @@ -293,9 +299,11 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d } func (p *plug) doStats(ch interface{}, doFunc statsFunc) { + p.RLock() stats := p.d.GetStats() doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) doStatsLookup(ch, doFunc, "parental", &stats.Parental) + p.RUnlock() } // Describe is called by prometheus handler to know stat types @@ -365,12 +373,12 @@ func (p *plug) genSOA(r *dns.Msg) []dns.RR { } Ns := "fake-for-negative-caching.adguard.com." - soa := defaultSOA + soa := *defaultSOA soa.Hdr = header soa.Mbox = Mbox soa.Ns = Ns - soa.Serial = uint32(time.Now().Unix()) - return []dns.RR{soa} + soa.Serial = 100500 // faster than uint32(time.Now().Unix()) + return []dns.RR{&soa} } func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { @@ -397,13 +405,17 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn for _, question := range r.Question { host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) // is it a safesearch domain? + p.RLock() if val, ok := p.d.SafeSearchDomain(host); ok { rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { + p.RUnlock() return rcode, dnsfilter.Result{}, err } + p.RUnlock() return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err } + p.RUnlock() // is it in hosts? if val, ok := p.hosts[host]; ok { @@ -425,11 +437,14 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn } // needs to be filtered instead + p.RLock() result, err := p.d.CheckHost(host) if err != nil { log.Printf("plugin/dnsfilter: %s\n", err) + p.RUnlock() return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) } + p.RUnlock() if result.IsFiltered { switch result.Reason { diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go index 808450db..eb848fb8 100644 --- a/coredns_plugin/querylog.go +++ b/coredns_plugin/querylog.go @@ -10,6 +10,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "github.com/AdguardTeam/AdguardDNS/dnsfilter" @@ -23,6 +24,7 @@ const ( ) var ( + logBufferLock sync.RWMutex logBuffer []logEntry ) @@ -65,11 +67,13 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela } var flushBuffer []logEntry + logBufferLock.Lock() logBuffer = append(logBuffer, entry) if len(logBuffer) >= logBufferCap { flushBuffer = logBuffer logBuffer = nil } + logBufferLock.Unlock() if len(flushBuffer) > 0 { // write to file // do it in separate goroutine -- we are stalling DNS response this whole time @@ -81,7 +85,9 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela func handleQueryLog(w http.ResponseWriter, r *http.Request) { // TODO: fetch values from disk if len(logBuffer) < queryLogSize // TODO: cache output + logBufferLock.RLock() values := logBuffer + logBufferLock.RUnlock() var data = []map[string]interface{}{} for _, entry := range values { var q *dns.Msg diff --git a/stats.go b/stats.go index 29f36fbf..a3231faf 100644 --- a/stats.go +++ b/stats.go @@ -12,6 +12,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "syscall" "time" ) @@ -57,6 +58,7 @@ type stats struct { PerDay periodicStats LastSeen statsEntry + sync.RWMutex } var statistics stats @@ -71,10 +73,12 @@ func init() { } func purgeStats() { + statistics.Lock() initPeriodicStats(&statistics.PerSecond) initPeriodicStats(&statistics.PerMinute) initPeriodicStats(&statistics.PerHour) initPeriodicStats(&statistics.PerDay) + statistics.Unlock() } func runStatsCollectors() { @@ -121,10 +125,12 @@ func statsRotate(periodic *periodicStats, now time.Time, rotations int64) { // called every second, accumulates stats for each second, minute, hour and day func collectStats() { now := time.Now() + statistics.Lock() statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second)) statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute)) statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour)) statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24)) + statistics.Unlock() // grab HTTP from prometheus resp, err := client.Get("http://127.0.0.1:9153/metrics") @@ -191,6 +197,7 @@ func collectStats() { } // calculate delta + statistics.Lock() delta := calcDelta(entry, statistics.LastSeen) // apply delta to second/minute/hour/day @@ -201,6 +208,7 @@ func collectStats() { // save last seen statistics.LastSeen = entry + statistics.Unlock() } func calcDelta(current, seen statsEntry) statsEntry { @@ -245,7 +253,9 @@ func loadStats() error { func writeStats() error { statsFile := filepath.Join(config.ourBinaryDir, "stats.json") log.Printf("Writing JSON file: %s", statsFile) + statistics.RLock() json, err := json.MarshalIndent(statistics, "", " ") + statistics.RUnlock() if err != nil { log.Printf("Couldn't generate JSON: %s", err) return err