diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 7988d27e..32707aac 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -291,7 +291,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering var result Result var err error - result = d.processRewrites(host, qtype) + result = d.processRewrites(host) if result.Reason == ReasonRewrite { return result, nil } @@ -356,8 +356,8 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering // . if found, set domain name to canonical name // . repeat for the new domain name (Note: we return only the last CNAME) // . Find A or AAAA record for a domain name (exact match or by wildcard) -// . if found, return IP addresses -func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { +// . if found, return IP addresses (both IPv4 and IPv6) +func (d *Dnsfilter) processRewrites(host string) Result { var res Result d.confLock.RLock() @@ -384,7 +384,7 @@ func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { } for _, r := range rr { - if r.Type != dns.TypeCNAME && r.Type == qtype { + if r.Type != dns.TypeCNAME { res.IPList = append(res.IPList, r.IP) log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP) } diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 9d6a71c8..963926a9 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -544,20 +544,16 @@ func TestRewrites(t *testing.T) { RewriteEntry{"www.host.com", "host.com", 0, nil}, } d.prepareRewrites() - r := d.processRewrites("host2.com", dns.TypeA) + r := d.processRewrites("host2.com") assert.Equal(t, NotFilteredNotFound, r.Reason) - r = d.processRewrites("www.host.com", dns.TypeA) + r = d.processRewrites("www.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.True(t, len(r.IPList) == 2) + assert.True(t, len(r.IPList) == 3) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) - - r = d.processRewrites("www.host.com", dns.TypeAAAA) - assert.Equal(t, ReasonRewrite, r.Reason) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4"))) + assert.True(t, r.IPList[2].Equal(net.ParseIP("1:2:3::4"))) // wildcard d.Rewrites = []RewriteEntry{ @@ -565,15 +561,15 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("host.com", dns.TypeA) + r = d.processRewrites("host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - r = d.processRewrites("www.host.com", dns.TypeA) + r = d.processRewrites("www.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) - r = d.processRewrites("www.host2.com", dns.TypeA) + r = d.processRewrites("www.host2.com") assert.Equal(t, NotFilteredNotFound, r.Reason) // override a wildcard @@ -582,7 +578,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "1.2.3.5", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("a.host.com", dns.TypeA) + r = d.processRewrites("a.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.True(t, len(r.IPList) == 1) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) @@ -593,7 +589,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.host.com", "host.com", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("www.host.com", dns.TypeA) + r = d.processRewrites("www.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) @@ -605,7 +601,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"host.com", "1.2.3.4", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("b.host.com", dns.TypeA) + r = d.processRewrites("b.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "host.com", r.CanonName) assert.True(t, len(r.IPList) == 1) @@ -618,7 +614,7 @@ func TestRewrites(t *testing.T) { RewriteEntry{"*.somehost.com", "1.2.3.4", 0, nil}, } d.prepareRewrites() - r = d.processRewrites("b.host.com", dns.TypeA) + r = d.processRewrites("b.host.com") assert.Equal(t, ReasonRewrite, r.Reason) assert.Equal(t, "x.somehost.com", r.CanonName) assert.True(t, len(r.IPList) == 1) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 7c001e27..3ddda7f1 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -791,11 +791,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { } for _, ip := range res.IPList { - if req.Question[0].Qtype == dns.TypeA { - a := s.genAAnswer(req, ip) + ip4 := ip.To4() + if req.Question[0].Qtype == dns.TypeA && ip4 != nil { + a := s.genAAnswer(req, ip4) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA { + } else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil { a := s.genAAAAAnswer(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a)