* DNS rewrites: don't pass request to an upstream server if matched by Rewrite rule

For example, if there's an A rewrite rule, but no AAAA rule,
 the response to AAAA request must be empty.
This commit is contained in:
Simon Zolin 2020-03-02 11:49:26 +03:00
parent 80df44b316
commit 140d5553e7
3 changed files with 19 additions and 22 deletions

View file

@ -291,7 +291,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
var result Result
var err error
result = d.processRewrites(host, qtype)
result = d.processRewrites(host)
if result.Reason == ReasonRewrite {
return result, nil
}
@ -356,8 +356,8 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
// . if found, set domain name to canonical name
// . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, return IP addresses
func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
// . if found, return IP addresses (both IPv4 and IPv6)
func (d *Dnsfilter) processRewrites(host string) Result {
var res Result
d.confLock.RLock()
@ -384,7 +384,7 @@ func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
}
for _, r := range rr {
if r.Type != dns.TypeCNAME && r.Type == qtype {
if r.Type != dns.TypeCNAME {
res.IPList = append(res.IPList, r.IP)
log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP)
}

View file

@ -544,20 +544,16 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"www.host.com", "host.com", 0, nil},
}
d.prepareRewrites()
r := d.processRewrites("host2.com", dns.TypeA)
r := d.processRewrites("host2.com")
assert.Equal(t, NotFilteredNotFound, r.Reason)
r = d.processRewrites("www.host.com", dns.TypeA)
r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 2)
assert.True(t, len(r.IPList) == 3)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4")))
// wildcard
d.Rewrites = []RewriteEntry{
@ -565,15 +561,15 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA)
r = d.processRewrites("host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
r = d.processRewrites("www.host.com", dns.TypeA)
r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
r = d.processRewrites("www.host2.com", dns.TypeA)
r = d.processRewrites("www.host2.com")
assert.Equal(t, NotFilteredNotFound, r.Reason)
// override a wildcard
@ -582,7 +578,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
r = d.processRewrites("a.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
@ -593,7 +589,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "host.com", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA)
r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
@ -605,7 +601,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"host.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
r = d.processRewrites("b.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
@ -618,7 +614,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
r = d.processRewrites("b.host.com")
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)

View file

@ -791,11 +791,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
}
for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA {
a := s.genAAnswer(req, ip)
ip4 := ip.To4()
if req.Question[0].Qtype == dns.TypeA && ip4 != nil {
a := s.genAAnswer(req, ip4)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA {
} else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil {
a := s.genAAAAAnswer(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)