diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index acf1c6d2..4db69cd2 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -255,10 +255,14 @@ func (d *DNSFilter) GetConfig() (s Settings) { // WriteDiskConfig - write configuration func (d *DNSFilter) WriteDiskConfig(c *Config) { d.confLock.Lock() + defer d.confLock.Unlock() + *c = d.Config - c.Rewrites = rewriteArrayDup(d.Config.Rewrites) - // BlockedServices - d.confLock.Unlock() + c.Rewrites = cloneRewrites(c.Rewrites) +} + +func cloneRewrites(entries []RewriteEntry) (clone []RewriteEntry) { + return append([]RewriteEntry(nil), entries...) } // SetFilters - set new filters (synchronously or asynchronously) diff --git a/internal/filtering/rewrites.go b/internal/filtering/rewrites.go index 80786d76..90eb262e 100644 --- a/internal/filtering/rewrites.go +++ b/internal/filtering/rewrites.go @@ -27,8 +27,66 @@ type RewriteEntry struct { Type uint16 `yaml:"-"` } -func (r *RewriteEntry) equals(b RewriteEntry) bool { - return r.Domain == b.Domain && r.Answer == b.Answer +// equal returns true if the entry is considered equal to the other. +func (e *RewriteEntry) equal(other RewriteEntry) (ok bool) { + return e.Domain == other.Domain && e.Answer == other.Answer +} + +// matchesQType returns true if the entry matched qtype. +func (e *RewriteEntry) matchesQType(qtype uint16) (ok bool) { + // Add CNAMEs, since they match for all types requests. + if e.Type == dns.TypeCNAME { + return true + } + + // Reject types other than A and AAAA. + if qtype != dns.TypeA && qtype != dns.TypeAAAA { + return false + } + + // If the types match or the entry is set to allow only the other type, + // include them. + return e.Type == qtype || e.IP == nil +} + +// normalize makes sure that the a new or decoded entry is normalized with +// regards to domain name case, IP length, and so on. +func (e *RewriteEntry) normalize() { + // TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix + // and use it in matchDomainWildcard instead of using strings.ToLower + // everywhere. + e.Domain = strings.ToLower(e.Domain) + + switch e.Answer { + case "AAAA": + e.IP = nil + e.Type = dns.TypeAAAA + + return + case "A": + e.IP = nil + e.Type = dns.TypeA + + return + default: + // Go on. + } + + ip := net.ParseIP(e.Answer) + if ip == nil { + e.Type = dns.TypeCNAME + + return + } + + ip4 := ip.To4() + if ip4 != nil { + e.IP = ip4 + e.Type = dns.TypeA + } else { + e.IP = ip + e.Type = dns.TypeAAAA + } } func isWildcard(host string) bool { @@ -78,48 +136,9 @@ func (a rewritesSorted) Less(i, j int) bool { return len(a[i].Domain) > len(a[j].Domain) } -// prepare prepares the a new or decoded entry. -func (r *RewriteEntry) prepare() { - // TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix - // and use it in matchDomainWildcard instead of using strings.ToLower - // everywhere. - r.Domain = strings.ToLower(r.Domain) - - switch r.Answer { - case "AAAA": - r.IP = nil - r.Type = dns.TypeAAAA - - return - case "A": - r.IP = nil - r.Type = dns.TypeA - - return - default: - // Go on. - } - - ip := net.ParseIP(r.Answer) - if ip == nil { - r.Type = dns.TypeCNAME - - return - } - - ip4 := ip.To4() - if ip4 != nil { - r.IP = ip4 - r.Type = dns.TypeA - } else { - r.IP = ip - r.Type = dns.TypeAAAA - } -} - func (d *DNSFilter) prepareRewrites() { for i := range d.Rewrites { - d.Rewrites[i].prepare() + d.Rewrites[i].normalize() } } @@ -127,18 +146,15 @@ func (d *DNSFilter) prepareRewrites() { // CNAME, then A and AAAA; exact, then wildcard. If the host is matched // exactly, wildcard entries aren't returned. If the host matched by wildcards, // return the most specific for the question type. -func findRewrites(a []RewriteEntry, host string, qtype uint16) []RewriteEntry { +func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) { rr := rewritesSorted{} - for _, r := range a { - if r.Domain != host && !matchDomainWildcard(host, r.Domain) { + for _, e := range entries { + if e.Domain != host && !matchDomainWildcard(host, e.Domain) { continue } - // Return CNAMEs for all types requests, but only the - // appropriate ones for A and AAAA. - if r.Type == dns.TypeCNAME || - (r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA)) { - rr = append(rr, r) + if e.matchesQType(qtype) { + rr = append(rr, e) } } @@ -169,12 +185,6 @@ func max(a, b int) int { return b } -func rewriteArrayDup(a []RewriteEntry) []RewriteEntry { - a2 := make([]RewriteEntry, len(a)) - copy(a2, a) - return a2 -} - type rewriteEntryJSON struct { Domain string `json:"domain"` Answer string `json:"answer"` @@ -213,7 +223,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { Domain: jsent.Domain, Answer: jsent.Answer, } - ent.prepare() + ent.normalize() d.confLock.Lock() d.Config.Rewrites = append(d.Config.Rewrites, ent) d.confLock.Unlock() @@ -238,7 +248,7 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) arr := []RewriteEntry{} d.confLock.Lock() for _, ent := range d.Config.Rewrites { - if ent.equals(entDel) { + if ent.equal(entDel) { log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer) continue } diff --git a/internal/filtering/rewrites_test.go b/internal/filtering/rewrites_test.go index d0591988..551a7faf 100644 --- a/internal/filtering/rewrites_test.go +++ b/internal/filtering/rewrites_test.go @@ -316,7 +316,7 @@ func TestRewritesExceptionIP(t *testing.T) { }, { name: "match_AAAA_host3.com", host: "host3.com", - want: nil, + want: []net.IP{}, dtyp: dns.TypeAAAA, }}