From 2244c21b765e089e49fede0cc27af7f9f76bae4d Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Sun, 7 Oct 2018 00:58:59 +0300
Subject: [PATCH] Fix race conditions found by go's race detector

---
 config.go                        |  2 +-
 control.go                       |  7 +++++--
 coredns_plugin/coredns_plugin.go | 21 ++++++++++++++++++---
 coredns_plugin/querylog.go       |  6 ++++++
 stats.go                         | 10 ++++++++++
 5 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/config.go b/config.go
index 8a753a49..ec0843b2 100644
--- a/config.go
+++ b/config.go
@@ -27,7 +27,7 @@ type configuration struct {
 	Filters   []filter      `yaml:"filters"`
 	UserRules []string      `yaml:"user_rules"`
 
-	sync.Mutex `yaml:"-"`
+	sync.RWMutex `yaml:"-"`
 }
 
 type coreDNSConfig struct {
diff --git a/control.go b/control.go
index d60c6509..cc902f3a 100644
--- a/control.go
+++ b/control.go
@@ -789,10 +789,11 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
 		"enabled": config.CoreDNS.FilteringEnabled,
 	}
 
+	config.RLock()
 	data["filters"] = config.Filters
 	data["user_rules"] = config.UserRules
-
 	json, err := json.Marshal(data)
+	config.RUnlock()
 
 	if err != nil {
 		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
@@ -1122,7 +1123,6 @@ func runFilterRefreshers() {
 func refreshFiltersIfNeccessary() int {
 	now := time.Now()
 	config.Lock()
-	defer config.Unlock()
 
 	// deduplicate
 	// TODO: move it somewhere else
@@ -1154,6 +1154,7 @@ func refreshFiltersIfNeccessary() int {
 			updateCount++
 		}
 	}
+	config.Unlock()
 
 	if updateCount > 0 {
 		err := writeFilterFile()
@@ -1237,6 +1238,7 @@ func writeFilterFile() error {
 	log.Printf("Writing filter file: %s", filterpath)
 	// TODO: check if file contents have modified
 	data := []byte{}
+	config.RLock()
 	filters := config.Filters
 	for _, filter := range filters {
 		if !filter.Enabled {
@@ -1249,6 +1251,7 @@ func writeFilterFile() error {
 		data = append(data, []byte(rule)...)
 		data = append(data, '\n')
 	}
+	config.RUnlock()
 	err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
 	if err != nil {
 		log.Printf("Couldn't write filter file: %s", err)
diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go
index 530e2c2b..dcdb50c8 100644
--- a/coredns_plugin/coredns_plugin.go
+++ b/coredns_plugin/coredns_plugin.go
@@ -55,6 +55,8 @@ type plug struct {
 	ParentalBlockHost     string
 	QueryLogEnabled       bool
 	BlockedTTL            uint32 // in seconds, default 3600
+
+	sync.RWMutex
 }
 
 var defaultPlugin = plug{
@@ -246,17 +248,21 @@ func (p *plug) parseEtcHosts(text string) bool {
 }
 
 func (p *plug) onShutdown() error {
+	p.Lock()
 	p.d.Destroy()
 	p.d = nil
+	p.Unlock()
 	return nil
 }
 
 func (p *plug) onFinalShutdown() error {
+	logBufferLock.Lock()
 	err := flushToFile(logBuffer)
 	if err != nil {
 		log.Printf("failed to flush to file: %s", err)
 		return err
 	}
+	logBufferLock.Unlock()
 	return nil
 }
 
@@ -293,9 +299,11 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d
 }
 
 func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
+	p.RLock()
 	stats := p.d.GetStats()
 	doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
 	doStatsLookup(ch, doFunc, "parental", &stats.Parental)
+	p.RUnlock()
 }
 
 // Describe is called by prometheus handler to know stat types
@@ -365,12 +373,12 @@ func (p *plug) genSOA(r *dns.Msg) []dns.RR {
 	}
 	Ns := "fake-for-negative-caching.adguard.com."
 
-	soa := defaultSOA
+	soa := *defaultSOA
 	soa.Hdr = header
 	soa.Mbox = Mbox
 	soa.Ns = Ns
-	soa.Serial = uint32(time.Now().Unix())
-	return []dns.RR{soa}
+	soa.Serial = 100500 // faster than uint32(time.Now().Unix())
+	return []dns.RR{&soa}
 }
 
 func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
@@ -397,13 +405,17 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
 	for _, question := range r.Question {
 		host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
 		// is it a safesearch domain?
+		p.RLock()
 		if val, ok := p.d.SafeSearchDomain(host); ok {
 			rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
 			if err != nil {
+				p.RUnlock()
 				return rcode, dnsfilter.Result{}, err
 			}
+			p.RUnlock()
 			return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
 		}
+		p.RUnlock()
 
 		// is it in hosts?
 		if val, ok := p.hosts[host]; ok {
@@ -425,11 +437,14 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
 		}
 
 		// needs to be filtered instead
+		p.RLock()
 		result, err := p.d.CheckHost(host)
 		if err != nil {
 			log.Printf("plugin/dnsfilter: %s\n", err)
+			p.RUnlock()
 			return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
 		}
+		p.RUnlock()
 
 		if result.IsFiltered {
 			switch result.Reason {
diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go
index 808450db..eb848fb8 100644
--- a/coredns_plugin/querylog.go
+++ b/coredns_plugin/querylog.go
@@ -10,6 +10,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/AdguardTeam/AdguardDNS/dnsfilter"
@@ -23,6 +24,7 @@ const (
 )
 
 var (
+	logBufferLock sync.RWMutex
 	logBuffer     []logEntry
 )
 
@@ -65,11 +67,13 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 	}
 	var flushBuffer []logEntry
 
+	logBufferLock.Lock()
 	logBuffer = append(logBuffer, entry)
 	if len(logBuffer) >= logBufferCap {
 		flushBuffer = logBuffer
 		logBuffer = nil
 	}
+	logBufferLock.Unlock()
 	if len(flushBuffer) > 0 {
 		// write to file
 		// do it in separate goroutine -- we are stalling DNS response this whole time
@@ -81,7 +85,9 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 func handleQueryLog(w http.ResponseWriter, r *http.Request) {
 	// TODO: fetch values from disk if len(logBuffer) < queryLogSize
 	// TODO: cache output
+	logBufferLock.RLock()
 	values := logBuffer
+	logBufferLock.RUnlock()
 	var data = []map[string]interface{}{}
 	for _, entry := range values {
 		var q *dns.Msg
diff --git a/stats.go b/stats.go
index 29f36fbf..a3231faf 100644
--- a/stats.go
+++ b/stats.go
@@ -12,6 +12,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 )
@@ -57,6 +58,7 @@ type stats struct {
 	PerDay    periodicStats
 
 	LastSeen statsEntry
+	sync.RWMutex
 }
 
 var statistics stats
@@ -71,10 +73,12 @@ func init() {
 }
 
 func purgeStats() {
+	statistics.Lock()
 	initPeriodicStats(&statistics.PerSecond)
 	initPeriodicStats(&statistics.PerMinute)
 	initPeriodicStats(&statistics.PerHour)
 	initPeriodicStats(&statistics.PerDay)
+	statistics.Unlock()
 }
 
 func runStatsCollectors() {
@@ -121,10 +125,12 @@ func statsRotate(periodic *periodicStats, now time.Time, rotations int64) {
 // called every second, accumulates stats for each second, minute, hour and day
 func collectStats() {
 	now := time.Now()
+	statistics.Lock()
 	statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second))
 	statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute))
 	statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour))
 	statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24))
+	statistics.Unlock()
 
 	// grab HTTP from prometheus
 	resp, err := client.Get("http://127.0.0.1:9153/metrics")
@@ -191,6 +197,7 @@ func collectStats() {
 	}
 
 	// calculate delta
+	statistics.Lock()
 	delta := calcDelta(entry, statistics.LastSeen)
 
 	// apply delta to second/minute/hour/day
@@ -201,6 +208,7 @@ func collectStats() {
 
 	// save last seen
 	statistics.LastSeen = entry
+	statistics.Unlock()
 }
 
 func calcDelta(current, seen statsEntry) statsEntry {
@@ -245,7 +253,9 @@ func loadStats() error {
 func writeStats() error {
 	statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
 	log.Printf("Writing JSON file: %s", statsFile)
+	statistics.RLock()
 	json, err := json.MarshalIndent(statistics, "", "  ")
+	statistics.RUnlock()
 	if err != nil {
 		log.Printf("Couldn't generate JSON: %s", err)
 		return err