filtering: wildcard interference

This commit is contained in:
Stanislav Chzhen 2023-05-12 12:25:33 +03:00
parent c77b2a0ce5
commit 1f2ba07eae
2 changed files with 53 additions and 0 deletions

View file

@ -3,6 +3,7 @@ package filtering
import (
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// DNSRewriteResult is the result of application of $dnsrewrite rules.
@ -24,7 +25,13 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Response: DNSRewriteResultResponse{},
}
slices.SortFunc(dnsr, rewriteSortsBefore)
for _, nr := range dnsr {
if containsWildcard(nr) {
break
}
dr := nr.DNSRewrite
if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than other rules.
@ -73,3 +80,19 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Reason: RewrittenRule,
}
}
func rewriteSortsBefore(a, b *rules.NetworkRule) (sortsBefore bool) {
return len(a.Shortcut) > len(b.Shortcut)
}
func containsWildcard(r *rules.NetworkRule) (ok bool) {
for _, c := range r.RuleText {
if c == '*' {
return true
} else if c == '^' {
break
}
}
return false
}

View file

@ -5,6 +5,7 @@ import (
"path"
"testing"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -202,3 +203,32 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
assert.Equal(t, "new-ptr-with-dot.", ptr)
})
}
func TestDNSFilter_processDNSRewrites(t *testing.T) {
const text = `
|www.example.com^$dnsrewrite=127.0.0.1
|*.example.com^$dnsrewrite=127.0.0.2
`
host := "www.example.com"
rrtype := dns.TypeA
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
FilteringEnabled: true,
}
ufReq := &urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: rrtype,
}
dres, matched := f.filteringEngine.MatchRequest(ufReq)
require.False(t, matched)
res := f.processDNSResultRewrites(dres, host)
assert.Len(t, res.Rules, 1)
}