diff --git a/CHANGELOG.md b/CHANGELOG.md index eed5f0f1..efb91061 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,8 @@ In this release, the schema version has changed from 10 to 12. ### Fixed +- Incorrect `$dnsrewrite` results for entries from the operating system's hosts + file ([#3815]). - Matching against rules with `|` at the end of the domain name ([#3371]). - Incorrect assignment of explicitly configured DHCP options ([#3744]). - Occasional panic during shutdown ([#3655]). @@ -219,6 +221,7 @@ In this release, the schema version has changed from 10 to 12. [#3655]: https://github.com/AdguardTeam/AdGuardHome/issues/3655 [#3707]: https://github.com/AdguardTeam/AdGuardHome/issues/3707 [#3744]: https://github.com/AdguardTeam/AdGuardHome/issues/3744 +[#3815]: https://github.com/AdguardTeam/AdGuardHome/issues/3815 diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 852bf299..67aad623 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -17,12 +17,13 @@ import ( "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" + "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) // DefaultHostsPaths returns the slice of paths default for the operating system // to files and directories which are containing the hosts database. The result -// is intended to use within fs.FS so the initial slash is omitted. +// is intended to be used within fs.FS so the initial slash is omitted. func DefaultHostsPaths() (paths []string) { return defaultHostsPaths() } @@ -42,9 +43,10 @@ type HostsContainer struct { // engine serves rulesStrg. engine *urlfilter.DNSEngine - // Updates is the channel for receiving updated hosts. The receivable map's - // values has a type of slice of strings. + // updates is the channel for receiving updated hosts. 
updates chan *netutil.IPMap + // last is the set of hosts that was cached within the last detected change. + last *netutil.IPMap // fsys is the working file system to read hosts files from. fsys fs.FS @@ -81,6 +83,7 @@ func NewHostsContainer( hc = &HostsContainer{ engLock: &sync.RWMutex{}, updates: make(chan *netutil.IPMap, 1), + last: &netutil.IPMap{}, fsys: fsys, w: w, patterns: patterns, @@ -127,17 +130,20 @@ func (hc *HostsContainer) MatchRequest( hc.engLock.RLock() defer hc.engLock.RUnlock() - return hc.engine.MatchRequest(req) + res, ok = hc.engine.MatchRequest(req) + + return res, ok } // Close implements the io.Closer interface for *HostsContainer. func (hc *HostsContainer) Close() (err error) { - log.Debug("%s: closing hosts container", hostsContainerPref) + log.Debug("%s: closing", hostsContainerPref) return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref) } -// Upd returns the channel into which the updates are sent. +// Upd returns the channel into which the updates are sent.  The receivable +// map's values are guaranteed to be of type *stringutil.Set. func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { return hc.updates } @@ -185,11 +191,18 @@ type hostsParser struct { table *netutil.IPMap } -// parseHostsFile is a aghtest.FileWalker for parsing the files with hosts -// syntax. It never signs to stop the walking. +func (hc *HostsContainer) newHostsParser() (hp *hostsParser) { + return &hostsParser{ + rules: &strings.Builder{}, + table: netutil.NewIPMap(hc.last.Len()), + } +} + +// parseFile is an aghos.FileWalker for parsing the files with hosts syntax. It +// never signals to stop walking and never returns any additional patterns. // // See man hosts(5). 
-func (hp hostsParser) parseHostsFile( +func (hp *hostsParser) parseFile( r io.Reader, ) (patterns []string, cont bool, err error) { s := bufio.NewScanner(r) @@ -208,7 +221,7 @@ func (hp hostsParser) parseHostsFile( } // parseLine parses the line having the hosts syntax ignoring invalid ones. -func (hp hostsParser) parseLine(line string) (ip net.IP, hosts []string) { +func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { line = strings.TrimSpace(line) fields := strings.Fields(line) if len(fields) < 2 { @@ -240,20 +253,24 @@ loop: } // add returns true if the pair of ip and host wasn't added to the hp before. -func (hp hostsParser) add(ip net.IP, host string) (added bool) { +func (hp *hostsParser) add(ip net.IP, host string) (added bool) { v, ok := hp.table.Get(ip) - hosts, _ := v.([]string) - if ok && stringutil.InSlice(hosts, host) { + hosts, _ := v.(*stringutil.Set) + switch { + case ok && hosts.Has(host): return false + case hosts == nil: + hosts = stringutil.NewSet(host) + hp.table.Set(ip, hosts) + default: + hosts.Add(host) } - hp.table.Set(ip, append(hosts, host)) - return true } // addPair puts the pair of ip and host to the rules builder if needed. -func (hp hostsParser) addPair(ip net.IP, host string) { +func (hp *hostsParser) addPair(ip net.IP, host string) { arpa, err := netutil.IPToReversedAddr(ip) if err != nil { return @@ -269,61 +286,110 @@ func (hp hostsParser) addPair(ip net.IP, host string) { qtype = "A" } - stringutil.WriteToBuilder( - hp.rules, - "||", - host, - "^$dnsrewrite=NOERROR;", - qtype, - ";", - ip.String(), - "\n", - "||", - arpa, - "^$dnsrewrite=NOERROR;PTR;", - dns.Fqdn(host), - "\n", + const ( + nl = "\n" + sc = ";" + + rewriteSuccess = "$dnsrewrite=NOERROR" + sc + rewriteSuccessPTR = rewriteSuccess + "PTR" + sc ) + ipStr := ip.String() + fqdn := dns.Fqdn(host) + + for _, ruleData := range [...][]string{{ + // A/AAAA. 
+ rules.MaskStartURL, + host, + rules.MaskSeparator, + rewriteSuccess, + qtype, + sc, + ipStr, + nl, + }, { + // PTR. + rules.MaskStartURL, + arpa, + rules.MaskSeparator, + rewriteSuccessPTR, + fqdn, + nl, + }} { + stringutil.WriteToBuilder(hp.rules, ruleData...) + } + log.Debug("%s: added ip-host pair %q/%q", hostsContainerPref, ip, host) } +// equalSet returns true if the internal hosts table just parsed equals target. +func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) { + if hp.table.Len() != target.Len() { + return false + } + + hp.table.Range(func(ip net.IP, val interface{}) (cont bool) { + v, hasIP := target.Get(ip) + // ok is set to true if the target doesn't contain ip or if the + // appropriate hosts set isn't equal to the checked one, i.e. the maps + // have at least one discrepancy. + ok = !hasIP || !v.(*stringutil.Set).Equal(val.(*stringutil.Set)) + + // Continue only if the maps have no discrepancies. + return !ok + }) + + // Return true if every value from the IP map has no discrepancies with the + // appropriate one from the target. + return !ok +} + // sendUpd tries to send the parsed data to the ch. -func (hp hostsParser) sendUpd(ch chan *netutil.IPMap) { +func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) { log.Debug("%s: sending upd", hostsContainerPref) + + upd := hp.table select { - case ch <- hp.table: + case ch <- upd: // Updates are delivered. Go on. + case <-ch: + ch <- upd + log.Debug("%s: replaced the last update", hostsContainerPref) + case ch <- upd: + // The previous update was just read and the next one pushed. Go on. default: - log.Debug("%s: the buffer is full", hostsContainerPref) + log.Debug("%s: the channel is broken", hostsContainerPref) } } // newStrg creates a new rules storage from parsed data. 
-func (hp hostsParser) newStrg() (s *filterlist.RuleStorage, err error) { +func (hp *hostsParser) newStrg() (s *filterlist.RuleStorage, err error) { return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{ - ID: 1, + ID: -1, RulesText: hp.rules.String(), IgnoreCosmetic: true, }}) } -// refresh gets the data from specified files and propagates the updates. +// refresh gets the data from specified files and propagates the updates if +// needed. func (hc *HostsContainer) refresh() (err error) { log.Debug("%s: refreshing", hostsContainerPref) - hp := hostsParser{ - rules: &strings.Builder{}, - table: netutil.NewIPMap(0), + hp := hc.newHostsParser() + if _, err = aghos.FileWalker(hp.parseFile).Walk(hc.fsys, hc.patterns...); err != nil { + return fmt.Errorf("refreshing : %w", err) } - _, err = aghos.FileWalker(hp.parseHostsFile).Walk(hc.fsys, hc.patterns...) - if err != nil { - return fmt.Errorf("updating: %w", err) - } + if hp.equalSet(hc.last) { + log.Debug("%s: no updates detected", hostsContainerPref) + return nil + } defer hp.sendUpd(hc.updates) + hc.last = hp.table.ShallowClone() + var rulesStrg *filterlist.RuleStorage if rulesStrg, err = hp.newStrg(); err != nil { return fmt.Errorf("initializing rules storage: %w", err) diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 213f4e7b..b30790a3 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -164,11 +165,14 @@ func TestHostsContainer_Refresh(t *testing.T) { }, } - eventsCh := make(chan struct{}, 1) + // event is a convenient alias for an empty struct{} to emit test events. 
+ type event = struct{} + + eventsCh := make(chan event, 1) t.Cleanup(func() { close(eventsCh) }) w := &aghtest.FSWatcher{ - OnEvents: func() (e <-chan struct{}) { return eventsCh }, + OnEvents: func() (e <-chan event) { return eventsCh }, OnAdd: func(name string) (err error) { assert.Equal(t, dirname, name) @@ -181,7 +185,7 @@ func TestHostsContainer_Refresh(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) }) - checkRefresh := func(t *testing.T, wantHosts []string) { + checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) { upd, ok := <-hc.Upd() require.True(t, ok) require.NotNil(t, upd) @@ -191,26 +195,30 @@ func TestHostsContainer_Refresh(t *testing.T) { v, ok := upd.Get(knownIP) require.True(t, ok) - var hosts []string - hosts, ok = v.([]string) + var hosts *stringutil.Set + hosts, ok = v.(*stringutil.Set) require.True(t, ok) - require.Len(t, hosts, len(wantHosts)) - assert.Equal(t, wantHosts, hosts) + assert.True(t, hosts.Equal(wantHosts)) } t.Run("initial_refresh", func(t *testing.T) { - checkRefresh(t, []string{knownHost}) + checkRefresh(t, stringutil.NewSet(knownHost)) }) testFS[p2] = &fstest.MapFile{ Data: []byte(strings.Join([]string{knownIP.String(), knownAlias}, sp) + nl), } - - eventsCh <- struct{}{} + eventsCh <- event{} t.Run("second_refresh", func(t *testing.T) { - checkRefresh(t, []string{knownHost, knownAlias}) + checkRefresh(t, stringutil.NewSet(knownHost, knownAlias)) + }) + + eventsCh <- event{} + + t.Run("no_changes_refresh", func(t *testing.T) { + assert.Empty(t, hc.Upd()) }) } @@ -218,7 +226,7 @@ func TestHostsContainer_MatchRequest(t *testing.T) { var ( ip4 = net.IP{127, 0, 0, 1} ip6 = net.IP{ - 0, 0, 0, 0, + 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, @@ -236,9 +244,9 @@ func TestHostsContainer_MatchRequest(t *testing.T) { gsfs := fstest.MapFS{ filename: &fstest.MapFile{Data: []byte( - strings.Join([]string{ip4.String(), hostname4, hostname4a}, sp) + nl + - 
strings.Join([]string{ip6.String(), hostname6}, sp) + nl + - strings.Join([]string{"256.256.256.256", "fakebroadcast"}, sp) + nl, + ip4.String() + " " + hostname4 + " " + hostname4a + nl + + ip6.String() + " " + hostname6 + nl + + `256.256.256.256 fakebroadcast` + nl, )}, } @@ -265,6 +273,15 @@ func TestHostsContainer_MatchRequest(t *testing.T) { Hostname: hostname4, DNSType: dns.TypeA, }, + }, { + name: "a_for_aaaa", + want: []interface{}{ + ip4.To16(), + }, + req: urlfilter.DNSRequest{ + Hostname: hostname4, + DNSType: dns.TypeAAAA, + }, }, { name: "aaaa", want: []interface{}{ip6}, @@ -408,7 +425,7 @@ func TestUniqueRules_AddPair(t *testing.T) { const knownHost = "host1" ipToHost := netutil.NewIPMap(0) - ipToHost.Set(knownIP, []string{knownHost}) + ipToHost.Set(knownIP, *stringutil.NewSet(knownHost)) testCases := []struct { name string @@ -422,10 +439,11 @@ func TestUniqueRules_AddPair(t *testing.T) { "||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host2.\n", ip: knownIP, }, { - name: "existing_one", - host: knownHost, - wantRules: "", - ip: knownIP, + name: "existing_one", + host: knownHost, + wantRules: "||" + knownHost + "^$dnsrewrite=NOERROR;A;1.2.3.4\n" + + "||4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;host1.\n", + ip: knownIP, }, { name: "new_ip", host: knownHost, diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 7300b43c..471b463e 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -82,25 +82,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { // original question is readded in processFilteringAfterResponse. 
ctx.origQuestion = q req.Question[0].Name = dns.Fqdn(res.CanonName) - case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0: - resp := s.makeResponse(req) - hdr := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - for _, h := range res.ReverseHosts { - ptr := &dns.PTR{ - Hdr: hdr, - Ptr: h, - } - - resp.Answer = append(resp.Answer, ptr) - } - - d.Res = resp - case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts): + case res.Reason == filtering.Rewritten: resp := s.makeResponse(req) name := host @@ -123,7 +105,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } d.Res = resp - case res.Reason == filtering.RewrittenRule: + case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts): if err = s.filterDNSRewrite(req, res, d); err != nil { return nil, err } diff --git a/internal/filtering/dnsrewrite.go b/internal/filtering/dnsrewrite.go index e98dfa3d..a6dda4a6 100644 --- a/internal/filtering/dnsrewrite.go +++ b/internal/filtering/dnsrewrite.go @@ -17,12 +17,8 @@ type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue // processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty // result if dnsr is empty. Otherwise, the result will have either CanonName or -// DNSRewriteResult set. +// DNSRewriteResult set. dnsr is expected to be non-empty. func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { - if len(dnsr) == 0 { - return Result{} - } - var rules []*ResultRule dnsrr := &DNSRewriteResult{ Response: DNSRewriteResultResponse{}, @@ -31,8 +27,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { for _, nr := range dnsr { dr := nr.DNSRewrite if dr.NewCNAME != "" { - // NewCNAME rules have a higher priority than - // the other rules. + // NewCNAME rules have a higher priority than the other rules. 
rules = []*ResultRule{{ FilterListID: int64(nr.GetFilterListID()), Text: nr.RuleText, @@ -54,8 +49,8 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { Text: nr.RuleText, }) default: - // RcodeRefused and other such codes have higher - // priority. Return immediately. + // RcodeRefused and other such codes have higher priority. Return + // immediately. rules = []*ResultRule{{ FilterListID: int64(nr.GetFilterListID()), Text: nr.RuleText, diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index defd9e2f..c4dd4c05 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -378,6 +378,10 @@ type Result struct { // ReverseHosts is the reverse lookup rewrite result. It is empty unless // Reason is set to RewrittenAutoHosts. + // + // TODO(e.burkov): There is no need for AutoHosts-related fields any more + // since the hosts container now uses $dnsrewrite rules. These fields are + // only used in query log to decode old format. ReverseHosts []string `json:",omitempty"` // IPList is the lookup rewrite result. It is empty unless Reason is set to @@ -450,53 +454,39 @@ func (d *DNSFilter) CheckHost( } // matchSysHosts tries to match the host against the operating system's hosts -// database. +// database. err is always nil. func (d *DNSFilter) matchSysHosts( host string, qtype uint16, setts *Settings, ) (res Result, err error) { if !setts.FilteringEnabled || d.EtcHosts == nil { - return Result{}, nil + return res, nil } dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{ Hostname: host, SortedClientTags: setts.ClientTags, - // TODO(e.burkov): Wait for urlfilter update to pass net.IP. + // TODO(e.burkov): Wait for urlfilter update to pass net.IP. 
ClientIP: setts.ClientIP.String(), ClientName: setts.ClientName, DNSType: qtype, }) if dnsres == nil { - return Result{}, nil + return res, nil } - dnsr := dnsres.DNSRewrites() - if len(dnsr) == 0 { - return Result{}, nil + if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 { + // Check DNS rewrites first, because the API there is a bit awkward. + res = d.processDNSRewrites(dnsr) + res.Reason = RewrittenAutoHosts + // TODO(e.burkov): Put real hosts-syntax rules. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/3846. + res.Rules = nil } - var ips []net.IP - var revHosts []string - for _, nr := range dnsr { - if nr.DNSRewrite == nil { - continue - } - - switch val := nr.DNSRewrite.Value.(type) { - case net.IP: - ips = append(ips, val) - case string: - revHosts = append(revHosts, val) - } - } - - return Result{ - Reason: RewrittenAutoHosts, - IPList: ips, - ReverseHosts: revHosts, - }, nil + return res, nil } // Process rewrites table diff --git a/internal/home/clients.go b/internal/home/clients.go index 85eec9fd..02c87f03 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -772,17 +772,18 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { n := 0 hosts.Range(func(ip net.IP, v interface{}) (cont bool) { - names, ok := v.([]string) + names, ok := v.(*stringutil.Set) if !ok { log.Error("dns: bad type %T in ipToRC for %s", v, ip) } - for _, name := range names { - ok = clients.addHostLocked(ip, name, ClientSourceHostsFile) - if ok { + names.Range(func(name string) (cont bool) { + if clients.addHostLocked(ip, name, ClientSourceHostsFile) { n++ } - } + + return true + }) return true })