mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-03-14 22:48:35 +03:00
all: rewrite package dependency
This commit is contained in:
parent
c2abedec70
commit
18392943fa
14 changed files with 320 additions and 374 deletions
|
@ -68,7 +68,7 @@ func createTestServer(
|
|||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
f, err := filtering.New(filterConf, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
@ -761,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
|||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
|
@ -881,7 +881,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
|||
|
||||
func TestRewrite(t *testing.T) {
|
||||
c := &filtering.Config{
|
||||
Rewrites: []*rewrite.Item{{
|
||||
Rewrites: []*filtering.RewriteItem{{
|
||||
Domain: "test.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
|
@ -892,7 +892,11 @@ func TestRewrite(t *testing.T) {
|
|||
Answer: "example.org",
|
||||
}},
|
||||
}
|
||||
f, err := filtering.New(c, nil)
|
||||
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := filtering.New(c, nil, rewriteStorage)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
@ -943,6 +947,12 @@ func TestRewrite(t *testing.T) {
|
|||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
@ -953,6 +963,12 @@ func TestRewrite(t *testing.T) {
|
|||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
@ -966,6 +982,12 @@ func TestRewrite(t *testing.T) {
|
|||
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
|
||||
req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
}
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
|
@ -1010,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{
|
|||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
flt, err := filtering.New(&filtering.Config{}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
|
@ -1084,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
|||
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
flt.SetEnabled(true)
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
|||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ func TestFilters(t *testing.T) {
|
|||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &FilterYAML{
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
@ -91,7 +90,7 @@ type Config struct {
|
|||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
Rewrites []*rewrite.Item `yaml:"rewrites"`
|
||||
Rewrites []*RewriteItem `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
|
@ -195,7 +194,7 @@ type DNSFilter struct {
|
|||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
rewriteStorage *rewrite.DefaultStorage
|
||||
rewriteStorage RewriteStorage
|
||||
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
@ -544,6 +543,10 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
|||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
if d.rewriteStorage == nil {
|
||||
return res
|
||||
}
|
||||
|
||||
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: qtype,
|
||||
|
@ -893,7 +896,7 @@ func InitModule() {
|
|||
|
||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
|
@ -946,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
|||
|
||||
d.Config = *c
|
||||
d.filtersMu = &sync.RWMutex{}
|
||||
|
||||
d.rewriteStorage, err = rewrite.NewDefaultStorage(RewritesListID, d.Rewrites)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rewrites: init: %w", err)
|
||||
}
|
||||
d.rewriteStorage = rewriteStorage
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
|
@ -47,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
|||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.SafeBrowsingCacheSize = 10000
|
||||
c.ParentalCacheSize = 10000
|
||||
|
@ -59,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
|||
// It must not be nil.
|
||||
c = &Config{}
|
||||
}
|
||||
f, err := New(c, filters)
|
||||
|
||||
f, err := New(c, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
purgeCaches(f)
|
||||
|
@ -695,96 +696,6 @@ func TestMatching(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
rewrites := []*rewrite.Item{{
|
||||
Domain: "example.org",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
Domain: "example-v6.org",
|
||||
Answer: "1:2:3::4",
|
||||
}, {
|
||||
Domain: "cname.org",
|
||||
Answer: "cname-res.org",
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantReason Reason
|
||||
wantIsFiltered bool
|
||||
qtype uint16
|
||||
}{{
|
||||
name: "not_found_a",
|
||||
host: "not-example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "not_found_aaaa",
|
||||
host: "not-example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
qtype: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "not_found_txt",
|
||||
host: "not-example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
qtype: dns.TypeTXT,
|
||||
}, {
|
||||
name: "found_a",
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: Rewritten,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "found_aaaa",
|
||||
host: "example-v6.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: Rewritten,
|
||||
qtype: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "found_txt",
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
qtype: dns.TypeTXT,
|
||||
}, {
|
||||
name: "cname_a",
|
||||
host: "cname.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: Rewritten,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "cname_aaaa",
|
||||
host: "cname.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: Rewritten,
|
||||
qtype: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "cname_txt",
|
||||
host: "cname.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
qtype: dns.TypeTXT,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
d, setts := newForTest(t, &Config{
|
||||
Rewrites: rewrites,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.qtype, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantIsFiltered, res.IsFiltered)
|
||||
assert.Equal(t, tc.wantReason, res.Reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelist(t *testing.T) {
|
||||
rules := `||host1^
|
||||
||host2^
|
||||
|
|
|
@ -105,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
|||
},
|
||||
ConfigModified: func() { confModifiedCalled = true },
|
||||
DataDir: filtersDir,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
|
|
42
internal/filtering/rewrite.go
Normal file
42
internal/filtering/rewrite.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package filtering
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
|
||||
// RewriteStorage is a storage for rewrite rules.
|
||||
type RewriteStorage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *RewriteItem) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *RewriteItem) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*RewriteItem)
|
||||
}
|
||||
|
||||
// RewriteItem is a single DNS rewrite record.
|
||||
type RewriteItem struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain" json:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer" json:"answer"`
|
||||
}
|
||||
|
||||
// Equal returns true if rw is Equal to other.
|
||||
func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Item is a single DNS rewrite record.
|
||||
type Item struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain" json:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer" json:"answer"`
|
||||
}
|
||||
|
||||
// equal returns true if rw is equal to other.
|
||||
func (rw *Item) equal(other *Item) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func (rw *Item) toRule() (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rw.rewriteParams()
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// rewriteParams returns dns request type and exception flag for rw.
|
||||
func (rw *Item) rewriteParams() (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
|
@ -1,124 +0,0 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &Item{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *Item
|
||||
right *Item
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &Item{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestItem_toRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *Item
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &Item{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.item.toRule()
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,9 +3,11 @@ package rewrite
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
|
@ -15,21 +17,6 @@ import (
|
|||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Storage is a storage for rewrite rules.
|
||||
type Storage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *Item) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *Item) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*Item)
|
||||
}
|
||||
|
||||
// DefaultStorage is the default storage for rewrite rules.
|
||||
type DefaultStorage struct {
|
||||
// mu protects items.
|
||||
|
@ -42,7 +29,7 @@ type DefaultStorage struct {
|
|||
ruleList filterlist.RuleList
|
||||
|
||||
// rewrites stores the rewrite entries from configuration.
|
||||
rewrites []*Item
|
||||
rewrites []*filtering.RewriteItem
|
||||
|
||||
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
|
||||
//
|
||||
|
@ -53,10 +40,10 @@ type DefaultStorage struct {
|
|||
|
||||
// NewDefaultStorage returns new rewrites storage. listID is used as an
|
||||
// identifier of the underlying rules list. rewrites must not be nil.
|
||||
func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) {
|
||||
func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) {
|
||||
s = &DefaultStorage{
|
||||
mu: &sync.RWMutex{},
|
||||
urlFilterID: listID,
|
||||
urlFilterID: filtering.RewritesListID,
|
||||
rewrites: rewrites,
|
||||
}
|
||||
|
||||
|
@ -69,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err
|
|||
}
|
||||
|
||||
// type check
|
||||
var _ Storage = (*DefaultStorage)(nil)
|
||||
var _ filtering.RewriteStorage = (*DefaultStorage)(nil)
|
||||
|
||||
// MatchRequest implements the [Storage] interface for *DefaultStorage.
|
||||
// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
@ -160,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [
|
|||
return res.DNSRewrites()
|
||||
}
|
||||
|
||||
// Add implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *Item) (err error) {
|
||||
// Add implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
@ -171,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) {
|
|||
return s.resetRules()
|
||||
}
|
||||
|
||||
// Remove implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *Item) (err error) {
|
||||
// Remove implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
arr := []*Item{}
|
||||
arr := []*filtering.RewriteItem{}
|
||||
|
||||
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
|
||||
for _, ent := range s.rewrites {
|
||||
if ent.equal(item) {
|
||||
if ent.Equal(item) {
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
|
@ -193,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) {
|
|||
return s.resetRules()
|
||||
}
|
||||
|
||||
// List implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*Item) {
|
||||
// List implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*filtering.RewriteItem) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
|
@ -206,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) {
|
|||
// TODO(a.garipov): Use strings.Builder.
|
||||
var rulesText []string
|
||||
for _, rewrite := range s.rewrites {
|
||||
rulesText = append(rulesText, rewrite.toRule())
|
||||
rulesText = append(rulesText, toRule(rewrite))
|
||||
}
|
||||
|
||||
strList := &filterlist.StringRuleList{
|
||||
|
@ -247,3 +234,46 @@ func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) {
|
|||
func isWildcard(pat string) (res bool) {
|
||||
return strings.HasPrefix(pat, "|*.")
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func toRule(rw *filtering.RewriteItem) (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rewriteParams(rw)
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// RewriteParams returns dns request type and exception flag for rw.
|
||||
func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -12,32 +13,32 @@ import (
|
|||
)
|
||||
|
||||
func TestNewDefaultStorage(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "example.com",
|
||||
Answer: "answer.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, s.List(), 1)
|
||||
}
|
||||
|
||||
func TestDefaultStorage_CRUD(t *testing.T) {
|
||||
var items []*Item
|
||||
var items []*filtering.RewriteItem
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, s.List(), 0)
|
||||
|
||||
item := &Item{Domain: "example.com", Answer: "answer.com"}
|
||||
item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"}
|
||||
|
||||
err = s.Add(item)
|
||||
require.NoError(t, err)
|
||||
|
||||
list := s.List()
|
||||
require.Len(t, list, 1)
|
||||
require.True(t, item.equal(list[0]))
|
||||
require.True(t, item.Equal(list[0]))
|
||||
|
||||
err = s.Remove(item)
|
||||
require.NoError(t, err)
|
||||
|
@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
|
@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
|||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -285,7 +286,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
|||
|
||||
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
|
@ -296,7 +297,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
|||
Answer: "3.3.3.3",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -355,7 +356,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
|||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
// Wildcard and exception for a sub-domain.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
|
@ -366,7 +367,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
|||
Answer: "sub.host.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -410,7 +411,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
|||
|
||||
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
||||
// Two cname rules for one subdomain
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "cname.org",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
|
@ -424,7 +425,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
|||
Answer: "sub_cname.org",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -478,7 +479,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
|||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
// Exception for AAAA record.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
|
@ -495,7 +496,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
|||
Answer: "A",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -556,3 +557,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *filtering.RewriteItem
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := toRule(tc.item)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
61
internal/filtering/rewrite_test.go
Normal file
61
internal/filtering/rewrite_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package filtering
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *RewriteItem
|
||||
right *RewriteItem
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &RewriteItem{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.Equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
|
@ -16,7 +15,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
rw := &rewrite.Item{}
|
||||
rw := &RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(rw)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
@ -43,7 +42,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
|||
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
||||
entDel := rewrite.Item{}
|
||||
entDel := RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(&entDel)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
@ -76,30 +77,21 @@ func initDNSServer() (err error) {
|
|||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rewrites: init: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
privateNets, err := initPrivateNets()
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
|
@ -146,6 +138,29 @@ func initDNSServer() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func initPrivateNets() (privateNets netutil.SubnetSet, err error) {
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
}
|
||||
|
||||
return privateNets, nil
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue