Fix race conditions found by go's race detector

This commit is contained in:
Eugene Bujak 2018-10-07 00:58:59 +03:00
parent 2c33905a79
commit 2244c21b76
5 changed files with 40 additions and 6 deletions

View file

@ -27,7 +27,7 @@ type configuration struct {
Filters []filter `yaml:"filters"` Filters []filter `yaml:"filters"`
UserRules []string `yaml:"user_rules"` UserRules []string `yaml:"user_rules"`
sync.Mutex `yaml:"-"` sync.RWMutex `yaml:"-"`
} }
type coreDNSConfig struct { type coreDNSConfig struct {

View file

@ -789,10 +789,11 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
"enabled": config.CoreDNS.FilteringEnabled, "enabled": config.CoreDNS.FilteringEnabled,
} }
config.RLock()
data["filters"] = config.Filters data["filters"] = config.Filters
data["user_rules"] = config.UserRules data["user_rules"] = config.UserRules
json, err := json.Marshal(data) json, err := json.Marshal(data)
config.RUnlock()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
@ -1122,7 +1123,6 @@ func runFilterRefreshers() {
func refreshFiltersIfNeccessary() int { func refreshFiltersIfNeccessary() int {
now := time.Now() now := time.Now()
config.Lock() config.Lock()
defer config.Unlock()
// deduplicate // deduplicate
// TODO: move it somewhere else // TODO: move it somewhere else
@ -1154,6 +1154,7 @@ func refreshFiltersIfNeccessary() int {
updateCount++ updateCount++
} }
} }
config.Unlock()
if updateCount > 0 { if updateCount > 0 {
err := writeFilterFile() err := writeFilterFile()
@ -1237,6 +1238,7 @@ func writeFilterFile() error {
log.Printf("Writing filter file: %s", filterpath) log.Printf("Writing filter file: %s", filterpath)
// TODO: check if file contents have modified // TODO: check if file contents have modified
data := []byte{} data := []byte{}
config.RLock()
filters := config.Filters filters := config.Filters
for _, filter := range filters { for _, filter := range filters {
if !filter.Enabled { if !filter.Enabled {
@ -1249,6 +1251,7 @@ func writeFilterFile() error {
data = append(data, []byte(rule)...) data = append(data, []byte(rule)...)
data = append(data, '\n') data = append(data, '\n')
} }
config.RUnlock()
err := ioutil.WriteFile(filterpath+".tmp", data, 0644) err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
if err != nil { if err != nil {
log.Printf("Couldn't write filter file: %s", err) log.Printf("Couldn't write filter file: %s", err)

View file

@ -55,6 +55,8 @@ type plug struct {
ParentalBlockHost string ParentalBlockHost string
QueryLogEnabled bool QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600 BlockedTTL uint32 // in seconds, default 3600
sync.RWMutex
} }
var defaultPlugin = plug{ var defaultPlugin = plug{
@ -246,17 +248,21 @@ func (p *plug) parseEtcHosts(text string) bool {
} }
func (p *plug) onShutdown() error { func (p *plug) onShutdown() error {
p.Lock()
p.d.Destroy() p.d.Destroy()
p.d = nil p.d = nil
p.Unlock()
return nil return nil
} }
func (p *plug) onFinalShutdown() error { func (p *plug) onFinalShutdown() error {
logBufferLock.Lock()
err := flushToFile(logBuffer) err := flushToFile(logBuffer)
if err != nil { if err != nil {
log.Printf("failed to flush to file: %s", err) log.Printf("failed to flush to file: %s", err)
return err return err
} }
logBufferLock.Unlock()
return nil return nil
} }
@ -293,9 +299,11 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d
} }
func (p *plug) doStats(ch interface{}, doFunc statsFunc) { func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
p.RLock()
stats := p.d.GetStats() stats := p.d.GetStats()
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
doStatsLookup(ch, doFunc, "parental", &stats.Parental) doStatsLookup(ch, doFunc, "parental", &stats.Parental)
p.RUnlock()
} }
// Describe is called by prometheus handler to know stat types // Describe is called by prometheus handler to know stat types
@ -365,12 +373,12 @@ func (p *plug) genSOA(r *dns.Msg) []dns.RR {
} }
Ns := "fake-for-negative-caching.adguard.com." Ns := "fake-for-negative-caching.adguard.com."
soa := defaultSOA soa := *defaultSOA
soa.Hdr = header soa.Hdr = header
soa.Mbox = Mbox soa.Mbox = Mbox
soa.Ns = Ns soa.Ns = Ns
soa.Serial = uint32(time.Now().Unix()) soa.Serial = 100500 // faster than uint32(time.Now().Unix())
return []dns.RR{soa} return []dns.RR{&soa}
} }
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
@ -397,13 +405,17 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
for _, question := range r.Question { for _, question := range r.Question {
host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
// is it a safesearch domain? // is it a safesearch domain?
p.RLock()
if val, ok := p.d.SafeSearchDomain(host); ok { if val, ok := p.d.SafeSearchDomain(host); ok {
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil { if err != nil {
p.RUnlock()
return rcode, dnsfilter.Result{}, err return rcode, dnsfilter.Result{}, err
} }
p.RUnlock()
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
} }
p.RUnlock()
// is it in hosts? // is it in hosts?
if val, ok := p.hosts[host]; ok { if val, ok := p.hosts[host]; ok {
@ -425,11 +437,14 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
} }
// needs to be filtered instead // needs to be filtered instead
p.RLock()
result, err := p.d.CheckHost(host) result, err := p.d.CheckHost(host)
if err != nil { if err != nil {
log.Printf("plugin/dnsfilter: %s\n", err) log.Printf("plugin/dnsfilter: %s\n", err)
p.RUnlock()
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
} }
p.RUnlock()
if result.IsFiltered { if result.IsFiltered {
switch result.Reason { switch result.Reason {

View file

@ -10,6 +10,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/AdguardDNS/dnsfilter" "github.com/AdguardTeam/AdguardDNS/dnsfilter"
@ -23,6 +24,7 @@ const (
) )
var ( var (
logBufferLock sync.RWMutex
logBuffer []logEntry logBuffer []logEntry
) )
@ -65,11 +67,13 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
} }
var flushBuffer []logEntry var flushBuffer []logEntry
logBufferLock.Lock()
logBuffer = append(logBuffer, entry) logBuffer = append(logBuffer, entry)
if len(logBuffer) >= logBufferCap { if len(logBuffer) >= logBufferCap {
flushBuffer = logBuffer flushBuffer = logBuffer
logBuffer = nil logBuffer = nil
} }
logBufferLock.Unlock()
if len(flushBuffer) > 0 { if len(flushBuffer) > 0 {
// write to file // write to file
// do it in separate goroutine -- we are stalling DNS response this whole time // do it in separate goroutine -- we are stalling DNS response this whole time
@ -81,7 +85,9 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
func handleQueryLog(w http.ResponseWriter, r *http.Request) { func handleQueryLog(w http.ResponseWriter, r *http.Request) {
// TODO: fetch values from disk if len(logBuffer) < queryLogSize // TODO: fetch values from disk if len(logBuffer) < queryLogSize
// TODO: cache output // TODO: cache output
logBufferLock.RLock()
values := logBuffer values := logBuffer
logBufferLock.RUnlock()
var data = []map[string]interface{}{} var data = []map[string]interface{}{}
for _, entry := range values { for _, entry := range values {
var q *dns.Msg var q *dns.Msg

View file

@ -12,6 +12,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
) )
@ -57,6 +58,7 @@ type stats struct {
PerDay periodicStats PerDay periodicStats
LastSeen statsEntry LastSeen statsEntry
sync.RWMutex
} }
var statistics stats var statistics stats
@ -71,10 +73,12 @@ func init() {
} }
func purgeStats() { func purgeStats() {
statistics.Lock()
initPeriodicStats(&statistics.PerSecond) initPeriodicStats(&statistics.PerSecond)
initPeriodicStats(&statistics.PerMinute) initPeriodicStats(&statistics.PerMinute)
initPeriodicStats(&statistics.PerHour) initPeriodicStats(&statistics.PerHour)
initPeriodicStats(&statistics.PerDay) initPeriodicStats(&statistics.PerDay)
statistics.Unlock()
} }
func runStatsCollectors() { func runStatsCollectors() {
@ -121,10 +125,12 @@ func statsRotate(periodic *periodicStats, now time.Time, rotations int64) {
// called every second, accumulates stats for each second, minute, hour and day // called every second, accumulates stats for each second, minute, hour and day
func collectStats() { func collectStats() {
now := time.Now() now := time.Now()
statistics.Lock()
statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second)) statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second))
statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute)) statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute))
statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour)) statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour))
statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24)) statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24))
statistics.Unlock()
// grab HTTP from prometheus // grab HTTP from prometheus
resp, err := client.Get("http://127.0.0.1:9153/metrics") resp, err := client.Get("http://127.0.0.1:9153/metrics")
@ -191,6 +197,7 @@ func collectStats() {
} }
// calculate delta // calculate delta
statistics.Lock()
delta := calcDelta(entry, statistics.LastSeen) delta := calcDelta(entry, statistics.LastSeen)
// apply delta to second/minute/hour/day // apply delta to second/minute/hour/day
@ -201,6 +208,7 @@ func collectStats() {
// save last seen // save last seen
statistics.LastSeen = entry statistics.LastSeen = entry
statistics.Unlock()
} }
func calcDelta(current, seen statsEntry) statsEntry { func calcDelta(current, seen statsEntry) statsEntry {
@ -245,7 +253,9 @@ func loadStats() error {
func writeStats() error { func writeStats() error {
statsFile := filepath.Join(config.ourBinaryDir, "stats.json") statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
log.Printf("Writing JSON file: %s", statsFile) log.Printf("Writing JSON file: %s", statsFile)
statistics.RLock()
json, err := json.MarshalIndent(statistics, "", " ") json, err := json.MarshalIndent(statistics, "", " ")
statistics.RUnlock()
if err != nil { if err != nil {
log.Printf("Couldn't generate JSON: %s", err) log.Printf("Couldn't generate JSON: %s", err)
return err return err