diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index 7f8a56d9..06cd921b 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -1,17 +1,11 @@ package filtering import ( - "fmt" "net/netip" "path" "testing" - "testing/fstest" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -215,173 +209,3 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { assert.Equal(t, "new-ptr-with-dot.", ptr) }) } - -func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { - addrv4 := netip.MustParseAddr("1.2.3.4") - addrv6 := netip.MustParseAddr("::1") - addrMapped := netip.MustParseAddr("::ffff:1.2.3.4") - addrv4Dup := netip.MustParseAddr("4.3.2.1") - - data := fmt.Sprintf( - ""+ - "%[1]s v4.host.example\n"+ - "%[2]s v6.host.example\n"+ - "%[3]s mapped.host.example\n"+ - "%[4]s v4.host.with-dup\n"+ - "%[4]s v4.host.with-dup\n", - addrv4, - addrv6, - addrMapped, - addrv4Dup, - ) - - files := fstest.MapFS{ - "hosts": &fstest.MapFile{ - Data: []byte(data), - }, - } - watcher := &aghtest.FSWatcher{ - OnEvents: func() (e <-chan struct{}) { return nil }, - OnAdd: func(name string) (err error) { return nil }, - OnClose: func() (err error) { return nil }, - } - hc, err := aghnet.NewHostsContainer(files, watcher, "hosts") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, hc.Close) - - f, _ := newForTest(t, &Config{EtcHosts: hc}, nil) - setts := &Settings{ - FilteringEnabled: true, - } - - testCases := []struct { - name string - host string - wantRules []*ResultRule - wantResps []rules.RRValue - dtyp uint16 - }{{ - name: "v4", - host: "v4.host.example", - dtyp: dns.TypeA, - wantRules: []*ResultRule{{ - Text: "1.2.3.4 v4.host.example", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{addrv4}, - }, { - name: "v6", - host: "v6.host.example", - dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ - Text: "::1 v6.host.example", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{addrv6}, - }, { - name: "mapped", - host: "mapped.host.example", - dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ - Text: "::ffff:1.2.3.4 mapped.host.example", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{addrMapped}, - }, { - name: "ptr", - host: "4.3.2.1.in-addr.arpa", - dtyp: dns.TypePTR, - wantRules: []*ResultRule{{ - Text: "1.2.3.4 v4.host.example", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{"v4.host.example"}, - }, { - name: "ptr-mapped", - host: "4.0.3.0.2.0.1.0.f.f.f.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - dtyp: dns.TypePTR, - wantRules: []*ResultRule{{ - Text: "::ffff:1.2.3.4 mapped.host.example", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{"mapped.host.example"}, - }, { - name: "not_found_v4", - host: "non.existent.example", - dtyp: dns.TypeA, - wantRules: nil, - wantResps: nil, - }, { - name: "not_found_v6", - host: "non.existent.example", - dtyp: dns.TypeAAAA, - wantRules: nil, - wantResps: nil, - }, { - name: "not_found_ptr", - host: "4.3.2.2.in-addr.arpa", - dtyp: dns.TypePTR, - wantRules: nil, - wantResps: nil, - }, { - name: "v4_mismatch", - host: "v4.host.example", - dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ - Text: fmt.Sprintf("%s v4.host.example", addrv4), - FilterListID: SysHostsListID, - }}, - wantResps: nil, - }, { - name: "v6_mismatch", - host: "v6.host.example", - dtyp: dns.TypeA, - wantRules: []*ResultRule{{ - Text: fmt.Sprintf("%s v6.host.example", addrv6), - FilterListID: SysHostsListID, - }}, - wantResps: nil, - }, { - name: "wrong_ptr", - host: "4.3.2.1.ip6.arpa", - dtyp: dns.TypePTR, - wantRules: nil, - wantResps: nil, - }, { - name: "unsupported_type", - host: "v4.host.example", - dtyp: dns.TypeCNAME, - wantRules: nil, - wantResps: nil, - }, { - name: "v4_dup", - host: "v4.host.with-dup", - dtyp: dns.TypeA, - wantRules: []*ResultRule{{ - Text: "4.3.2.1 v4.host.with-dup", - FilterListID: SysHostsListID, - }}, - wantResps: []rules.RRValue{addrv4Dup}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var res Result - res, err = f.CheckHost(tc.host, tc.dtyp, setts) - require.NoError(t, err) - - if len(tc.wantRules) == 0 { - assert.Empty(t, res.Rules) - assert.Nil(t, res.DNSRewriteResult) - - return - } - - require.NotNil(t, res.DNSRewriteResult) - require.Contains(t, res.DNSRewriteResult.Response, tc.dtyp) - - assert.Equal(t, tc.wantResps, res.DNSRewriteResult.Response[tc.dtyp]) - assert.Equal(t, tc.wantRules, res.Rules) - }) - } -} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 68f56e94..2d382530 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -23,7 +23,6 @@ import ( "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/mathutil" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/syncutil" "github.com/AdguardTeam/urlfilter" @@ -618,88 +617,6 @@ func (d *DNSFilter) CheckHost( return Result{}, nil } -// matchSysHosts tries to match the host against the operating system's hosts -// database. err is always nil. -func (d *DNSFilter) matchSysHosts( - host string, - qtype uint16, - setts *Settings, -) (res Result, err error) { - // TODO(e.burkov): Where else is this checked? - if !setts.FilteringEnabled || d.conf.EtcHosts == nil { - return Result{}, nil - } - - vals, rs, matched := hostsRewrites(qtype, host, d.conf.EtcHosts) - if matched { - res.DNSRewriteResult = &DNSRewriteResult{ - Response: DNSRewriteResultResponse{ - qtype: vals, - }, - RCode: dns.RcodeSuccess, - } - res.Rules = rs - res.Reason = RewrittenAutoHosts - } - - return res, nil -} - -// hostsRewrites returns values and rules matched by qt and host within hs. -func hostsRewrites( - qtype uint16, - host string, - hs hostsfile.Storage, -) (vals []rules.RRValue, rls []*ResultRule, matched bool) { - var isValidProto func(netip.Addr) (ok bool) - switch qtype { - case dns.TypeA: - isValidProto = netip.Addr.Is4 - case dns.TypeAAAA: - isValidProto = netip.Addr.Is6 - case dns.TypePTR: - // TODO(e.burkov): Add some [netip]-aware alternative to [netutil]. - ip, err := netutil.IPFromReversedAddr(host) - if err != nil { - log.Debug("filtering: failed to parse PTR record %q: %s", host, err) - - return nil, nil, false - } - - addr, _ := netip.AddrFromSlice(ip) - - for _, name := range hs.ByAddr(addr) { - matched = true - - vals = append(vals, name) - rls = append(rls, &ResultRule{ - Text: fmt.Sprintf("%s %s", addr, name), - FilterListID: SysHostsListID, - }) - } - - return vals, rls, matched - default: - log.Debug("filtering: unsupported qtype %d", qtype) - - return nil, nil, false - } - - for _, addr := range hs.ByName(host) { - matched = true - - if isValidProto(addr) { - vals = append(vals, addr) - } - rls = append(rls, &ResultRule{ - Text: fmt.Sprintf("%s %s", addr, host), - FilterListID: SysHostsListID, - }) - } - - return vals, rls, matched -} - // processRewrites performs filtering based on the legacy rewrite records. // // Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host, diff --git a/internal/filtering/hosts.go b/internal/filtering/hosts.go new file mode 100644 index 00000000..2f747669 --- /dev/null +++ b/internal/filtering/hosts.go @@ -0,0 +1,94 @@ +package filtering + +import ( + "fmt" + "net/netip" + + "github.com/AdguardTeam/golibs/hostsfile" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" +) + +// matchSysHosts tries to match the host against the operating system's hosts +// database. err is always nil. +func (d *DNSFilter) matchSysHosts( + host string, + qtype uint16, + setts *Settings, +) (res Result, err error) { + // TODO(e.burkov): Where else is this checked? + if !setts.FilteringEnabled || d.conf.EtcHosts == nil { + return Result{}, nil + } + + vals, rs, matched := hostsRewrites(qtype, host, d.conf.EtcHosts) + if !matched { + return Result{}, nil + } + + return Result{ + DNSRewriteResult: &DNSRewriteResult{ + Response: DNSRewriteResultResponse{ + qtype: vals, + }, + RCode: dns.RcodeSuccess, + }, + Rules: rs, + Reason: RewrittenAutoHosts, + }, nil +} + +// hostsRewrites returns values and rules matched by qt and host within hs. +func hostsRewrites( + qtype uint16, + host string, + hs hostsfile.Storage, +) (vals []rules.RRValue, rls []*ResultRule, matched bool) { + var isValidProto func(netip.Addr) (ok bool) + switch qtype { + case dns.TypeA: + isValidProto = netip.Addr.Is4 + case dns.TypeAAAA: + isValidProto = netip.Addr.Is6 + case dns.TypePTR: + // TODO(e.burkov): Add some [netip]-aware alternative to [netutil]. + ip, err := netutil.IPFromReversedAddr(host) + if err != nil { + log.Debug("filtering: failed to parse PTR record %q: %s", host, err) + + return nil, nil, false + } + + addr, _ := netip.AddrFromSlice(ip) + names := hs.ByAddr(addr) + + for _, name := range names { + vals = append(vals, name) + rls = append(rls, &ResultRule{ + Text: fmt.Sprintf("%s %s", addr, name), + FilterListID: SysHostsListID, + }) + } + + return vals, rls, len(names) > 0 + default: + log.Debug("filtering: unsupported qtype %d", qtype) + + return nil, nil, false + } + + addrs := hs.ByName(host) + for _, addr := range addrs { + if isValidProto(addr) { + vals = append(vals, addr) + } + rls = append(rls, &ResultRule{ + Text: fmt.Sprintf("%s %s", addr, host), + FilterListID: SysHostsListID, + }) + } + + return vals, rls, len(addrs) > 0 +} diff --git a/internal/filtering/hosts_test.go b/internal/filtering/hosts_test.go new file mode 100644 index 00000000..baa6675c --- /dev/null +++ b/internal/filtering/hosts_test.go @@ -0,0 +1,191 @@ +package filtering + +import ( + "fmt" + "net/netip" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { + addrv4 := netip.MustParseAddr("1.2.3.4") + addrv6 := netip.MustParseAddr("::1") + addrMapped := netip.MustParseAddr("::ffff:1.2.3.4") + addrv4Dup := netip.MustParseAddr("4.3.2.1") + + data := fmt.Sprintf( + ""+ + "%[1]s v4.host.example\n"+ + "%[2]s v6.host.example\n"+ + "%[3]s mapped.host.example\n"+ + "%[4]s v4.host.with-dup\n"+ + "%[4]s v4.host.with-dup\n", + addrv4, + addrv6, + addrMapped, + addrv4Dup, + ) + + files := fstest.MapFS{ + "hosts": &fstest.MapFile{ + Data: []byte(data), + }, + } + watcher := &aghtest.FSWatcher{ + OnEvents: func() (e <-chan struct{}) { return nil }, + OnAdd: func(name string) (err error) { return nil }, + OnClose: func() (err error) { return nil }, + } + hc, err := aghnet.NewHostsContainer(files, watcher, "hosts") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, hc.Close) + + conf := &Config{ + EtcHosts: hc, + } + f, err := New(conf, nil) + require.NoError(t, err) + + setts := &Settings{ + FilteringEnabled: true, + } + + testCases := []struct { + name string + host string + wantRules []*ResultRule + wantResps []rules.RRValue + dtyp uint16 + }{{ + name: "v4", + host: "v4.host.example", + dtyp: dns.TypeA, + wantRules: []*ResultRule{{ + Text: "1.2.3.4 v4.host.example", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{addrv4}, + }, { + name: "v6", + host: "v6.host.example", + dtyp: dns.TypeAAAA, + wantRules: []*ResultRule{{ + Text: "::1 v6.host.example", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{addrv6}, + }, { + name: "mapped", + host: "mapped.host.example", + dtyp: dns.TypeAAAA, + wantRules: []*ResultRule{{ + Text: "::ffff:1.2.3.4 mapped.host.example", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{addrMapped}, + }, { + name: "ptr", + host: "4.3.2.1.in-addr.arpa", + dtyp: dns.TypePTR, + wantRules: []*ResultRule{{ + Text: "1.2.3.4 v4.host.example", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{"v4.host.example"}, + }, { + name: "ptr-mapped", + host: "4.0.3.0.2.0.1.0.f.f.f.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + dtyp: dns.TypePTR, + wantRules: []*ResultRule{{ + Text: "::ffff:1.2.3.4 mapped.host.example", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{"mapped.host.example"}, + }, { + name: "not_found_v4", + host: "non.existent.example", + dtyp: dns.TypeA, + wantRules: nil, + wantResps: nil, + }, { + name: "not_found_v6", + host: "non.existent.example", + dtyp: dns.TypeAAAA, + wantRules: nil, + wantResps: nil, + }, { + name: "not_found_ptr", + host: "4.3.2.2.in-addr.arpa", + dtyp: dns.TypePTR, + wantRules: nil, + wantResps: nil, + }, { + name: "v4_mismatch", + host: "v4.host.example", + dtyp: dns.TypeAAAA, + wantRules: []*ResultRule{{ + Text: fmt.Sprintf("%s v4.host.example", addrv4), + FilterListID: SysHostsListID, + }}, + wantResps: nil, + }, { + name: "v6_mismatch", + host: "v6.host.example", + dtyp: dns.TypeA, + wantRules: []*ResultRule{{ + Text: fmt.Sprintf("%s v6.host.example", addrv6), + FilterListID: SysHostsListID, + }}, + wantResps: nil, + }, { + name: "wrong_ptr", + host: "4.3.2.1.ip6.arpa", + dtyp: dns.TypePTR, + wantRules: nil, + wantResps: nil, + }, { + name: "unsupported_type", + host: "v4.host.example", + dtyp: dns.TypeCNAME, + wantRules: nil, + wantResps: nil, + }, { + name: "v4_dup", + host: "v4.host.with-dup", + dtyp: dns.TypeA, + wantRules: []*ResultRule{{ + Text: "4.3.2.1 v4.host.with-dup", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{addrv4Dup}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var res Result + res, err = f.CheckHost(tc.host, tc.dtyp, setts) + require.NoError(t, err) + + if len(tc.wantRules) == 0 { + assert.Empty(t, res.Rules) + assert.Nil(t, res.DNSRewriteResult) + + return + } + + require.NotNil(t, res.DNSRewriteResult) + require.Contains(t, res.DNSRewriteResult.Response, tc.dtyp) + + assert.Equal(t, tc.wantResps, res.DNSRewriteResult.Response[tc.dtyp]) + assert.Equal(t, tc.wantRules, res.Rules) + }) + } +}