From 3b1faa1365dc72c01728263b4dfd61c5274dbf6d Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Sun, 7 Oct 2018 21:24:22 +0300 Subject: [PATCH] Fix more race conditions found by race detector --- coredns_plugin/coredns_plugin.go | 40 ++++++++++++++++++-------------- stats.go | 2 +- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index dcdb50c8..2db776ac 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -45,21 +45,24 @@ func init() { }) } +type plugSettings struct { + SafeBrowsingBlockHost string + ParentalBlockHost string + QueryLogEnabled bool + BlockedTTL uint32 // in seconds, default 3600 +} + type plug struct { d *dnsfilter.Dnsfilter Next plugin.Handler upstream upstream.Upstream hosts map[string]net.IP - - SafeBrowsingBlockHost string - ParentalBlockHost string - QueryLogEnabled bool - BlockedTTL uint32 // in seconds, default 3600 + settings plugSettings sync.RWMutex } -var defaultPlugin = plug{ +var defaultPluginSettings = plugSettings{ SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com", BlockedTTL: 3600, // in seconds @@ -91,10 +94,11 @@ var ( // func setupPlugin(c *caddy.Controller) (*plug, error) { // create new Plugin and copy default values - var p = new(plug) - *p = defaultPlugin - p.d = dnsfilter.New() - p.hosts = make(map[string]net.IP) + p := &plug{ + settings: defaultPluginSettings, + d: dnsfilter.New(), + hosts: make(map[string]net.IP), + } filterFileNames := []string{} for c.Next() { @@ -130,7 +134,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { if len(c.Val()) == 0 { return nil, c.ArgErr() } - p.ParentalBlockHost = c.Val() + p.settings.ParentalBlockHost = c.Val() } case "blocked_ttl": if !c.NextArg() { @@ -140,9 +144,9 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { if err != nil { return nil, c.ArgErr() } - p.BlockedTTL = uint32(blockttl) + p.settings.BlockedTTL = uint32(blockttl) case "querylog": - p.QueryLogEnabled = true + p.settings.QueryLogEnabled = true onceQueryLog.Do(func() { go startQueryLogServer() // TODO: how to handle errors? }) @@ -323,7 +327,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri log.Println("Will give", val, "instead of", host) if addr != nil { // this is an IP address, return it - result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.BlockedTTL, val)) + result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val)) if err != nil { log.Printf("Got error %s\n", err) return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) @@ -365,7 +369,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri // the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant func (p *plug) genSOA(r *dns.Msg) []dns.RR { zone := r.Question[0].Name - header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.BlockedTTL, Class: dns.ClassINET} + header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET} Mbox := "hostmaster." if zone[0] != '.' { @@ -450,7 +454,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn switch result.Reason { case dnsfilter.FilteredSafeBrowsing: // return cname safebrowsing.block.dns.adguard.com - val := p.SafeBrowsingBlockHost + val := p.settings.SafeBrowsingBlockHost rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { return rcode, dnsfilter.Result{}, err @@ -458,7 +462,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn return rcode, result, err case dnsfilter.FilteredParental: // return cname family.block.dns.adguard.com - val := p.ParentalBlockHost + val := p.settings.ParentalBlockHost rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) if err != nil { return rcode, dnsfilter.Result{}, err @@ -549,7 +553,7 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( } // log - if p.QueryLogEnabled { + if p.settings.QueryLogEnabled { logRequest(r, rrw.Msg, result, time.Since(start), ip) } return rcode, err diff --git a/stats.go b/stats.go index a3231faf..f06380d2 100644 --- a/stats.go +++ b/stats.go @@ -254,7 +254,7 @@ func writeStats() error { statsFile := filepath.Join(config.ourBinaryDir, "stats.json") log.Printf("Writing JSON file: %s", statsFile) statistics.RLock() - json, err := json.MarshalIndent(statistics, "", " ") + json, err := json.MarshalIndent(&statistics, "", " ") statistics.RUnlock() if err != nil { log.Printf("Couldn't generate JSON: %s", err)