Fix more race conditions found by race detector

This commit is contained in:
Eugene Bujak 2018-10-07 21:24:22 +03:00
parent dc1042c3e9
commit 3b1faa1365
2 changed files with 23 additions and 19 deletions

View file

@ -45,21 +45,24 @@ func init() {
})
}
type plugSettings struct {
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600
}
type plug struct {
d *dnsfilter.Dnsfilter
Next plugin.Handler
upstream upstream.Upstream
hosts map[string]net.IP
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600
settings plugSettings
sync.RWMutex
}
var defaultPlugin = plug{
var defaultPluginSettings = plugSettings{
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
ParentalBlockHost: "family.block.dns.adguard.com",
BlockedTTL: 3600, // in seconds
@ -91,10 +94,11 @@ var (
//
func setupPlugin(c *caddy.Controller) (*plug, error) {
// create new Plugin and copy default values
var p = new(plug)
*p = defaultPlugin
p.d = dnsfilter.New()
p.hosts = make(map[string]net.IP)
p := &plug{
settings: defaultPluginSettings,
d: dnsfilter.New(),
hosts: make(map[string]net.IP),
}
filterFileNames := []string{}
for c.Next() {
@ -130,7 +134,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
if len(c.Val()) == 0 {
return nil, c.ArgErr()
}
p.ParentalBlockHost = c.Val()
p.settings.ParentalBlockHost = c.Val()
}
case "blocked_ttl":
if !c.NextArg() {
@ -140,9 +144,9 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
if err != nil {
return nil, c.ArgErr()
}
p.BlockedTTL = uint32(blockttl)
p.settings.BlockedTTL = uint32(blockttl)
case "querylog":
p.QueryLogEnabled = true
p.settings.QueryLogEnabled = true
onceQueryLog.Do(func() {
go startQueryLogServer() // TODO: how to handle errors?
})
@ -323,7 +327,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri
log.Println("Will give", val, "instead of", host)
if addr != nil {
// this is an IP address, return it
result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.BlockedTTL, val))
result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val))
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
@ -365,7 +369,7 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri
// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant
func (p *plug) genSOA(r *dns.Msg) []dns.RR {
zone := r.Question[0].Name
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.BlockedTTL, Class: dns.ClassINET}
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET}
Mbox := "hostmaster."
if zone[0] != '.' {
@ -450,7 +454,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
switch result.Reason {
case dnsfilter.FilteredSafeBrowsing:
// return cname safebrowsing.block.dns.adguard.com
val := p.SafeBrowsingBlockHost
val := p.settings.SafeBrowsingBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil {
return rcode, dnsfilter.Result{}, err
@ -458,7 +462,7 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
return rcode, result, err
case dnsfilter.FilteredParental:
// return cname family.block.dns.adguard.com
val := p.ParentalBlockHost
val := p.settings.ParentalBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil {
return rcode, dnsfilter.Result{}, err
@ -549,7 +553,7 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
}
// log
if p.QueryLogEnabled {
if p.settings.QueryLogEnabled {
logRequest(r, rrw.Msg, result, time.Since(start), ip)
}
return rcode, err

View file

@ -254,7 +254,7 @@ func writeStats() error {
statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
log.Printf("Writing JSON file: %s", statsFile)
statistics.RLock()
json, err := json.MarshalIndent(statistics, "", " ")
json, err := json.MarshalIndent(&statistics, "", " ")
statistics.RUnlock()
if err != nil {
log.Printf("Couldn't generate JSON: %s", err)