all: rewrite package dependency

This commit is contained in:
Dimitry Kolyshev 2022-12-30 12:36:47 +07:00
parent c2abedec70
commit 18392943fa
14 changed files with 320 additions and 374 deletions

View file

@ -68,7 +68,7 @@ func createTestServer(
ID: 0, Data: []byte(rules),
}}
f, err := filtering.New(filterConf, filters)
f, err := filtering.New(filterConf, filters, nil)
require.NoError(t, err)
f.SetEnabled(true)
@ -761,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
f, err := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{}, filters, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
@ -881,7 +881,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
func TestRewrite(t *testing.T) {
c := &filtering.Config{
Rewrites: []*rewrite.Item{{
Rewrites: []*filtering.RewriteItem{{
Domain: "test.com",
Answer: "1.2.3.4",
}, {
@ -892,7 +892,11 @@ func TestRewrite(t *testing.T) {
Answer: "example.org",
}},
}
f, err := filtering.New(c, nil)
rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites)
require.NoError(t, err)
f, err := filtering.New(c, nil, rewriteStorage)
require.NoError(t, err)
f.SetEnabled(true)
@ -943,6 +947,12 @@ func TestRewrite(t *testing.T) {
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("test.com.", dns.TypeTXT)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
@ -953,6 +963,12 @@ func TestRewrite(t *testing.T) {
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("alias.test.com.", dns.TypeTXT)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
@ -966,6 +982,12 @@ func TestRewrite(t *testing.T) {
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
}
for _, protect := range []bool{true, false} {
@ -1010,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{
func TestPTRResponseFromDHCPLeases(t *testing.T) {
const localDomain = "lan"
flt, err := filtering.New(&filtering.Config{}, nil)
flt, err := filtering.New(&filtering.Config{}, nil, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
@ -1084,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
flt, err := filtering.New(&filtering.Config{
EtcHosts: hc,
}, nil)
}, nil, nil)
require.NoError(t, err)
flt.SetEnabled(true)

View file

@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
ID: 0, Data: []byte(rules),
}}
f, err := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{}, filters, nil)
require.NoError(t, err)
f.SetEnabled(true)

View file

@ -68,7 +68,7 @@ func TestFilters(t *testing.T) {
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
}, nil)
}, nil, nil)
require.NoError(t, err)
f := &FilterYAML{

View file

@ -18,7 +18,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
@ -91,7 +90,7 @@ type Config struct {
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
Rewrites []*rewrite.Item `yaml:"rewrites"`
Rewrites []*RewriteItem `yaml:"rewrites"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
@ -195,7 +194,7 @@ type DNSFilter struct {
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
filterTitleRegexp *regexp.Regexp
rewriteStorage *rewrite.DefaultStorage
rewriteStorage RewriteStorage
hostCheckers []hostChecker
}
@ -544,6 +543,10 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
d.confLock.RLock()
defer d.confLock.RUnlock()
if d.rewriteStorage == nil {
return res
}
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
Hostname: host,
DNSType: qtype,
@ -893,7 +896,7 @@ func InitModule() {
// New creates properly initialized DNS Filter that is ready to be used. c must
// be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) {
d = &DNSFilter{
resolver: net.DefaultResolver,
refreshLock: &sync.Mutex{},
@ -946,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d.Config = *c
d.filtersMu = &sync.RWMutex{}
d.rewriteStorage, err = rewrite.NewDefaultStorage(RewritesListID, d.Rewrites)
if err != nil {
return nil, fmt.Errorf("rewrites: init: %w", err)
}
d.rewriteStorage = rewriteStorage
bsvcs := []string{}
for _, s := range d.BlockedServices {

View file

@ -9,7 +9,6 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/testutil"
@ -47,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
ProtectionEnabled: true,
FilteringEnabled: true,
}
if c != nil {
c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000
@ -59,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
// It must not be nil.
c = &Config{}
}
f, err := New(c, filters)
f, err := New(c, filters, nil)
require.NoError(t, err)
purgeCaches(f)
@ -695,96 +696,6 @@ func TestMatching(t *testing.T) {
}
}
func TestRewrites(t *testing.T) {
rewrites := []*rewrite.Item{{
Domain: "example.org",
Answer: "1.1.1.1",
}, {
Domain: "example-v6.org",
Answer: "1:2:3::4",
}, {
Domain: "cname.org",
Answer: "cname-res.org",
}}
testCases := []struct {
name string
host string
wantReason Reason
wantIsFiltered bool
qtype uint16
}{{
name: "not_found_a",
host: "not-example.org",
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
qtype: dns.TypeA,
}, {
name: "not_found_aaaa",
host: "not-example.org",
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
qtype: dns.TypeAAAA,
}, {
name: "not_found_txt",
host: "not-example.org",
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
qtype: dns.TypeTXT,
}, {
name: "found_a",
host: "example.org",
wantIsFiltered: false,
wantReason: Rewritten,
qtype: dns.TypeA,
}, {
name: "found_aaaa",
host: "example-v6.org",
wantIsFiltered: false,
wantReason: Rewritten,
qtype: dns.TypeAAAA,
}, {
name: "found_txt",
host: "example.org",
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
qtype: dns.TypeTXT,
}, {
name: "cname_a",
host: "cname.org",
wantIsFiltered: false,
wantReason: Rewritten,
qtype: dns.TypeA,
}, {
name: "cname_aaaa",
host: "cname.org",
wantIsFiltered: false,
wantReason: Rewritten,
qtype: dns.TypeAAAA,
}, {
name: "cname_txt",
host: "cname.org",
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
qtype: dns.TypeTXT,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
d, setts := newForTest(t, &Config{
Rewrites: rewrites,
}, nil)
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.qtype, setts)
require.NoError(t, err)
assert.Equal(t, tc.wantIsFiltered, res.IsFiltered)
assert.Equal(t, tc.wantReason, res.Reason)
})
}
}
func TestWhitelist(t *testing.T) {
rules := `||host1^
||host2^

View file

@ -105,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
},
ConfigModified: func() { confModifiedCalled = true },
DataDir: filtersDir,
}, nil)
}, nil, nil)
require.NoError(t, err)
t.Cleanup(d.Close)

View file

@ -0,0 +1,42 @@
package filtering
import (
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/rules"
)
// RewriteStorage is a storage for rewrite rules.
type RewriteStorage interface {
// MatchRequest returns matching dnsrewrites for the specified request.
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
// Add adds item to the storage.
Add(item *RewriteItem) (err error)
// Remove deletes item from the storage.
Remove(item *RewriteItem) (err error)
// List returns all items from the storage.
List() (items []*RewriteItem)
}
// RewriteItem is a single DNS rewrite record.
type RewriteItem struct {
// Domain is the domain pattern for which this rewrite should work.
Domain string `yaml:"domain" json:"domain"`
// Answer is the IP address, canonical name, or one of the special
// values: "A" or "AAAA".
Answer string `yaml:"answer" json:"answer"`
}
// Equal returns true if rw is Equal to other.
func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) {
if rw == nil {
return other == nil
} else if other == nil {
return false
}
return *rw == *other
}

View file

@ -1,73 +0,0 @@
package rewrite
import (
"fmt"
"net/netip"
"strings"
"github.com/miekg/dns"
)
// Item is a single DNS rewrite record.
type Item struct {
// Domain is the domain pattern for which this rewrite should work.
Domain string `yaml:"domain" json:"domain"`
// Answer is the IP address, canonical name, or one of the special
// values: "A" or "AAAA".
Answer string `yaml:"answer" json:"answer"`
}
// equal returns true if rw is equal to other.
func (rw *Item) equal(other *Item) (ok bool) {
if rw == nil {
return other == nil
} else if other == nil {
return false
}
return *rw == *other
}
// toRule converts rw to a filter rule.
func (rw *Item) toRule() (res string) {
if rw == nil {
return ""
}
domain := strings.ToLower(rw.Domain)
dType, exception := rw.rewriteParams()
dTypeKey := dns.TypeToString[dType]
if exception {
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
}
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
}
// rewriteParams returns dns request type and exception flag for rw.
func (rw *Item) rewriteParams() (dType uint16, exception bool) {
switch rw.Answer {
case "AAAA":
return dns.TypeAAAA, true
case "A":
return dns.TypeA, true
default:
// Go on.
}
addr, err := netip.ParseAddr(rw.Answer)
if err != nil {
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
return dns.TypeCNAME, false
}
if addr.Is4() {
dType = dns.TypeA
} else {
dType = dns.TypeAAAA
}
return dType, false
}

View file

@ -1,124 +0,0 @@
package rewrite
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestItem_equal(t *testing.T) {
const (
testDomain = "example.org"
testAnswer = "1.1.1.1"
)
testItem := &Item{
Domain: testDomain,
Answer: testAnswer,
}
testCases := []struct {
name string
left *Item
right *Item
want bool
}{{
name: "nil_left",
left: nil,
right: testItem,
want: false,
}, {
name: "nil_right",
left: testItem,
right: nil,
want: false,
}, {
name: "nils",
left: nil,
right: nil,
want: true,
}, {
name: "equal",
left: testItem,
right: testItem,
want: true,
}, {
name: "distinct",
left: testItem,
right: &Item{
Domain: "other",
Answer: "other",
},
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.left.equal(tc.right)
assert.Equal(t, tc.want, res)
})
}
}
func TestItem_toRule(t *testing.T) {
const testDomain = "example.org"
testCases := []struct {
name string
item *Item
want string
}{{
name: "nil",
item: nil,
want: "",
}, {
name: "a_rule",
item: &Item{
Domain: testDomain,
Answer: "1.1.1.1",
},
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
}, {
name: "aaaa_rule",
item: &Item{
Domain: testDomain,
Answer: "1:2:3::4",
},
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
}, {
name: "cname_rule",
item: &Item{
Domain: testDomain,
Answer: "other.org",
},
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
}, {
name: "wildcard_rule",
item: &Item{
Domain: "*.example.org",
Answer: "other.org",
},
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
}, {
name: "aaaa_exception",
item: &Item{
Domain: testDomain,
Answer: "A",
},
want: "@@||example.org^$dnstype=A,dnsrewrite",
}, {
name: "aaaa_exception",
item: &Item{
Domain: testDomain,
Answer: "AAAA",
},
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.item.toRule()
assert.Equal(t, tc.want, res)
})
}
}

View file

@ -3,9 +3,11 @@ package rewrite
import (
"fmt"
"net/netip"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
@ -15,21 +17,6 @@ import (
"golang.org/x/exp/slices"
)
// Storage is a storage for rewrite rules.
type Storage interface {
// MatchRequest returns matching dnsrewrites for the specified request.
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
// Add adds item to the storage.
Add(item *Item) (err error)
// Remove deletes item from the storage.
Remove(item *Item) (err error)
// List returns all items from the storage.
List() (items []*Item)
}
// DefaultStorage is the default storage for rewrite rules.
type DefaultStorage struct {
// mu protects items.
@ -42,7 +29,7 @@ type DefaultStorage struct {
ruleList filterlist.RuleList
// rewrites stores the rewrite entries from configuration.
rewrites []*Item
rewrites []*filtering.RewriteItem
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
//
@ -53,10 +40,10 @@ type DefaultStorage struct {
// NewDefaultStorage returns new rewrites storage. listID is used as an
// identifier of the underlying rules list. rewrites must not be nil.
func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) {
func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) {
s = &DefaultStorage{
mu: &sync.RWMutex{},
urlFilterID: listID,
urlFilterID: filtering.RewritesListID,
rewrites: rewrites,
}
@ -69,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err
}
// type check
var _ Storage = (*DefaultStorage)(nil)
var _ filtering.RewriteStorage = (*DefaultStorage)(nil)
// MatchRequest implements the [Storage] interface for *DefaultStorage.
// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage.
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
s.mu.RLock()
defer s.mu.RUnlock()
@ -160,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [
return res.DNSRewrites()
}
// Add implements the [Storage] interface for *DefaultStorage.
func (s *DefaultStorage) Add(item *Item) (err error) {
// Add implements the [RewriteStorage] interface for *DefaultStorage.
func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -171,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) {
return s.resetRules()
}
// Remove implements the [Storage] interface for *DefaultStorage.
func (s *DefaultStorage) Remove(item *Item) (err error) {
// Remove implements the [RewriteStorage] interface for *DefaultStorage.
func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
arr := []*Item{}
arr := []*filtering.RewriteItem{}
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
for _, ent := range s.rewrites {
if ent.equal(item) {
if ent.Equal(item) {
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
@ -193,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) {
return s.resetRules()
}
// List implements the [Storage] interface for *DefaultStorage.
func (s *DefaultStorage) List() (items []*Item) {
// List implements the [RewriteStorage] interface for *DefaultStorage.
func (s *DefaultStorage) List() (items []*filtering.RewriteItem) {
s.mu.RLock()
defer s.mu.RUnlock()
@ -206,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) {
// TODO(a.garipov): Use strings.Builder.
var rulesText []string
for _, rewrite := range s.rewrites {
rulesText = append(rulesText, rewrite.toRule())
rulesText = append(rulesText, toRule(rewrite))
}
strList := &filterlist.StringRuleList{
@ -247,3 +234,46 @@ func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) {
func isWildcard(pat string) (res bool) {
return strings.HasPrefix(pat, "|*.")
}
// toRule converts rw to a filter rule.
func toRule(rw *filtering.RewriteItem) (res string) {
if rw == nil {
return ""
}
domain := strings.ToLower(rw.Domain)
dType, exception := rewriteParams(rw)
dTypeKey := dns.TypeToString[dType]
if exception {
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
}
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
}
// RewriteParams returns dns request type and exception flag for rw.
func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) {
switch rw.Answer {
case "AAAA":
return dns.TypeAAAA, true
case "A":
return dns.TypeA, true
default:
// Go on.
}
addr, err := netip.ParseAddr(rw.Answer)
if err != nil {
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
return dns.TypeCNAME, false
}
if addr.Is4() {
dType = dns.TypeA
} else {
dType = dns.TypeAAAA
}
return dType, false
}

View file

@ -4,6 +4,7 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@ -12,32 +13,32 @@ import (
)
func TestNewDefaultStorage(t *testing.T) {
items := []*Item{{
items := []*filtering.RewriteItem{{
Domain: "example.com",
Answer: "answer.com",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
require.Len(t, s.List(), 1)
}
func TestDefaultStorage_CRUD(t *testing.T) {
var items []*Item
var items []*filtering.RewriteItem
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
require.Len(t, s.List(), 0)
item := &Item{Domain: "example.com", Answer: "answer.com"}
item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"}
err = s.Add(item)
require.NoError(t, err)
list := s.List()
require.Len(t, list, 1)
require.True(t, item.equal(list[0]))
require.True(t, item.Equal(list[0]))
err = s.Remove(item)
require.NoError(t, err)
@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
}
func TestDefaultStorage_MatchRequest(t *testing.T) {
items := []*Item{{
items := []*filtering.RewriteItem{{
// This one and below are about CNAME, A and AAAA.
Domain: "somecname",
Answer: "somehost.com",
@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
Answer: "sub.issue4016.com",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
testCases := []struct {
@ -285,7 +286,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
// Exact host, wildcard L2, wildcard L3.
items := []*Item{{
items := []*filtering.RewriteItem{{
Domain: "host.com",
Answer: "1.1.1.1",
}, {
@ -296,7 +297,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
Answer: "3.3.3.3",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
testCases := []struct {
@ -355,7 +356,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
// Wildcard and exception for a sub-domain.
items := []*Item{{
items := []*filtering.RewriteItem{{
Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
@ -366,7 +367,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
Answer: "sub.host.com",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
testCases := []struct {
@ -410,7 +411,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
// Two cname rules for one subdomain
items := []*Item{{
items := []*filtering.RewriteItem{{
Domain: "cname.org",
Answer: "1.1.1.1",
}, {
@ -424,7 +425,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
Answer: "sub_cname.org",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
testCases := []struct {
@ -478,7 +479,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
// Exception for AAAA record.
items := []*Item{{
items := []*filtering.RewriteItem{{
Domain: "host.com",
Answer: "1.2.3.4",
}, {
@ -495,7 +496,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
Answer: "A",
}}
s, err := NewDefaultStorage(-1, items)
s, err := NewDefaultStorage(items)
require.NoError(t, err)
testCases := []struct {
@ -556,3 +557,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
})
}
}
func TestToRule(t *testing.T) {
const testDomain = "example.org"
testCases := []struct {
name string
item *filtering.RewriteItem
want string
}{{
name: "nil",
item: nil,
want: "",
}, {
name: "a_rule",
item: &filtering.RewriteItem{
Domain: testDomain,
Answer: "1.1.1.1",
},
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
}, {
name: "aaaa_rule",
item: &filtering.RewriteItem{
Domain: testDomain,
Answer: "1:2:3::4",
},
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
}, {
name: "cname_rule",
item: &filtering.RewriteItem{
Domain: testDomain,
Answer: "other.org",
},
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
}, {
name: "wildcard_rule",
item: &filtering.RewriteItem{
Domain: "*.example.org",
Answer: "other.org",
},
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
}, {
name: "aaaa_exception",
item: &filtering.RewriteItem{
Domain: testDomain,
Answer: "A",
},
want: "@@||example.org^$dnstype=A,dnsrewrite",
}, {
name: "aaaa_exception",
item: &filtering.RewriteItem{
Domain: testDomain,
Answer: "AAAA",
},
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := toRule(tc.item)
assert.Equal(t, tc.want, res)
})
}
}

View file

@ -0,0 +1,61 @@
package filtering
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestItem_equal(t *testing.T) {
const (
testDomain = "example.org"
testAnswer = "1.1.1.1"
)
testItem := &RewriteItem{
Domain: testDomain,
Answer: testAnswer,
}
testCases := []struct {
name string
left *RewriteItem
right *RewriteItem
want bool
}{{
name: "nil_left",
left: nil,
right: testItem,
want: false,
}, {
name: "nil_right",
left: testItem,
right: nil,
want: false,
}, {
name: "nils",
left: nil,
right: nil,
want: true,
}, {
name: "equal",
left: testItem,
right: testItem,
want: true,
}, {
name: "distinct",
left: testItem,
right: &RewriteItem{
Domain: "other",
Answer: "other",
},
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.left.Equal(tc.right)
assert.Equal(t, tc.want, res)
})
}
}

View file

@ -5,7 +5,6 @@ import (
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
"github.com/AdguardTeam/golibs/log"
)
@ -16,7 +15,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
rw := &rewrite.Item{}
rw := &RewriteItem{}
err := json.NewDecoder(r.Body).Decode(rw)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
@ -43,7 +42,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
// API.
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
entDel := rewrite.Item{}
entDel := RewriteItem{}
err := json.NewDecoder(r.Body).Decode(&entDel)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)

View file

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
@ -76,30 +77,21 @@ func initDNSServer() (err error) {
}
Context.queryLog = querylog.New(conf)
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites)
if err != nil {
return fmt.Errorf("rewrites: init: %w", err)
}
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
privateNets, err := initPrivateNets()
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
p := dnsforward.DNSCreateParams{
@ -146,6 +138,29 @@ func initDNSServer() (err error) {
return nil
}
func initPrivateNets() (privateNets netutil.SubnetSet, err error) {
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
}
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
}
return privateNets, nil
}
func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}