* DNS rewrites: don't pass request to an upstream server if matched by Rewrite rule

For example, if there's an A rewrite rule, but no AAAA rule,
 the response to AAAA request must be empty.
This commit is contained in:
Simon Zolin 2020-03-02 11:49:26 +03:00
parent 80df44b316
commit 140d5553e7
3 changed files with 19 additions and 22 deletions

View file

@ -291,7 +291,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
var result Result var result Result
var err error var err error
result = d.processRewrites(host, qtype) result = d.processRewrites(host)
if result.Reason == ReasonRewrite { if result.Reason == ReasonRewrite {
return result, nil return result, nil
} }
@ -356,8 +356,8 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
// . if found, set domain name to canonical name // . if found, set domain name to canonical name
// . repeat for the new domain name (Note: we return only the last CNAME) // . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard) // . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, return IP addresses // . if found, return IP addresses (both IPv4 and IPv6)
func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { func (d *Dnsfilter) processRewrites(host string) Result {
var res Result var res Result
d.confLock.RLock() d.confLock.RLock()
@ -384,7 +384,7 @@ func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
} }
for _, r := range rr { for _, r := range rr {
if r.Type != dns.TypeCNAME && r.Type == qtype { if r.Type != dns.TypeCNAME {
res.IPList = append(res.IPList, r.IP) res.IPList = append(res.IPList, r.IP)
log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP) log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP)
} }

View file

@ -544,20 +544,16 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"www.host.com", "host.com", 0, nil}, RewriteEntry{"www.host.com", "host.com", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r := d.processRewrites("host2.com", dns.TypeA) r := d.processRewrites("host2.com")
assert.Equal(t, NotFilteredNotFound, r.Reason) assert.Equal(t, NotFilteredNotFound, r.Reason)
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 2) assert.True(t, len(r.IPList) == 3)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4")))
r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
// wildcard // wildcard
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
@ -565,15 +561,15 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA) r = d.processRewrites("host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
r = d.processRewrites("www.host2.com", dns.TypeA) r = d.processRewrites("www.host2.com")
assert.Equal(t, NotFilteredNotFound, r.Reason) assert.Equal(t, NotFilteredNotFound, r.Reason)
// override a wildcard // override a wildcard
@ -582,7 +578,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, RewriteEntry{"*.host.com", "1.2.3.5", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA) r = d.processRewrites("a.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1) assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
@ -593,7 +589,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.host.com", "host.com", 0, nil}, RewriteEntry{"*.host.com", "host.com", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
@ -605,7 +601,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"host.com", "1.2.3.4", 0, nil}, RewriteEntry{"host.com", "1.2.3.4", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA) r = d.processRewrites("b.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1) assert.True(t, len(r.IPList) == 1)
@ -618,7 +614,7 @@ func TestRewrites(t *testing.T) {
RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil},
} }
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA) r = d.processRewrites("b.host.com")
assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName) assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1) assert.True(t, len(r.IPList) == 1)

View file

@ -791,11 +791,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
} }
for _, ip := range res.IPList { for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA { ip4 := ip.To4()
a := s.genAAnswer(req, ip) if req.Question[0].Qtype == dns.TypeA && ip4 != nil {
a := s.genAAnswer(req, ip4)
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA { } else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil {
a := s.genAAAAAnswer(req, ip) a := s.genAAAAAnswer(req, ip)
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)