From 49a92605b8894ac66d0c290106ee63114ef4047a Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Tue, 23 Jun 2020 12:13:13 +0300
Subject: [PATCH] + dns: respond to PTR requests for internal IP addresses from
 DHCP Close #1682

Squashed commit of the following:

commit 2fad3544bf8853b1f8f19ad8b7bc8a490c96e533
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jun 22 17:32:45 2020 +0300

    minor

commit 7c17992424702d95e6de91f30e8ae2dfcd8de257
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jun 22 16:09:34 2020 +0300

    build

commit 16a52e11a015a97d3cbf30362482a4abd052192b
Merge: 7b6a73c8 2c47053c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jun 22 16:08:32 2020 +0300

    Merge remote-tracking branch 'origin/master' into 1682-dhcp-resolve

commit 7b6a73c84b5cb9a073a9dfb7d7bdecd22e1e1318
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jun 22 16:01:34 2020 +0300

    tests

commit c2654abb2e5e7b7e3a04e4ddb8e1064b37613929
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jun 1 15:15:13 2020 +0300

    + dnsforward: respond to PTR requests for internal IP addresses

    {[IP] => "host"} <- DNSforward <-(leases)-- DHCP
---
 dhcpd/dhcpd.go                | 10 +++--
 dnsforward/dnsforward.go      | 37 ++++++++++++-----
 dnsforward/dnsforward_test.go | 43 ++++++++++++++++++--
 dnsforward/handle_dns.go      | 75 +++++++++++++++++++++++++++++++++++
 home/dns.go                   |  8 +++-
 home/whois_test.go            |  2 +-
 util/auto_hosts.go            | 66 +-----------------------------
 util/auto_hosts_test.go       | 14 ++++---
 util/dns.go                   | 70 ++++++++++++++++++++++++++++++++
 9 files changed, 236 insertions(+), 89 deletions(-)
 create mode 100644 util/dns.go

diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go
index f95a9867..0ec83370 100644
--- a/dhcpd/dhcpd.go
+++ b/dhcpd/dhcpd.go
@@ -92,7 +92,7 @@ type Server struct {
 	conf ServerConfig
 
 	// Called when the leases DB is modified
-	onLeaseChanged onLeaseChangedT
+	onLeaseChanged []onLeaseChangedT
 }
 
 // Print information about the available network interfaces
@@ -146,14 +146,16 @@ func (s *Server) Init(config ServerConfig) error {
 
 // SetOnLeaseChanged - set callback
 func (s *Server) SetOnLeaseChanged(onLeaseChanged onLeaseChangedT) {
-	s.onLeaseChanged = onLeaseChanged
+	s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
 }
 
 func (s *Server) notify(flags int) {
-	if s.onLeaseChanged == nil {
+	if len(s.onLeaseChanged) == 0 {
 		return
 	}
-	s.onLeaseChanged(flags)
+	for _, f := range s.onLeaseChanged {
+		f(flags)
+	}
 }
 
 // WriteDiskConfig - write configuration
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index e8d8a6b0..36f5445e 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -8,6 +8,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/AdguardTeam/AdGuardHome/dhcpd"
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/AdGuardHome/querylog"
 	"github.com/AdguardTeam/AdGuardHome/stats"
@@ -43,11 +44,15 @@ var webRegistered bool
 //
 // The zero Server is empty and ready for use.
 type Server struct {
-	dnsProxy  *proxy.Proxy         // DNS proxy instance
-	dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
-	queryLog  querylog.QueryLog    // Query log instance
-	stats     stats.Stats
-	access    *accessCtx
+	dnsProxy   *proxy.Proxy         // DNS proxy instance
+	dnsFilter  *dnsfilter.Dnsfilter // DNS filter instance
+	dhcpServer *dhcpd.Server        // DHCP server instance (optional)
+	queryLog   querylog.QueryLog    // Query log instance
+	stats      stats.Stats
+	access     *accessCtx
+
+	tablePTR     map[string]string // "IP -> hostname" table for reverse lookup
+	tablePTRLock sync.Mutex
 
 	// DNS proxy instance for internal usage
 	// We don't Start() it and so no listen port is required.
@@ -59,13 +64,27 @@ type Server struct {
 	conf ServerConfig
 }
 
+// DNSCreateParams - parameters for NewServer()
+type DNSCreateParams struct {
+	DNSFilter  *dnsfilter.Dnsfilter
+	Stats      stats.Stats
+	QueryLog   querylog.QueryLog
+	DHCPServer *dhcpd.Server
+}
+
 // NewServer creates a new instance of the dnsforward.Server
 // Note: this function must be called only once
-func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog querylog.QueryLog) *Server {
+func NewServer(p DNSCreateParams) *Server {
 	s := &Server{}
-	s.dnsFilter = dnsFilter
-	s.stats = stats
-	s.queryLog = queryLog
+	s.dnsFilter = p.DNSFilter
+	s.stats = p.Stats
+	s.queryLog = p.QueryLog
+	s.dhcpServer = p.DHCPServer
+
+	if s.dhcpServer != nil {
+		s.dhcpServer.SetOnLeaseChanged(s.onDHCPLeaseChanged)
+		s.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
+	}
 
 	if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
 		// Use plain DNS on MIPS, encryption is too slow
diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go
index 42018265..773769ee 100644
--- a/dnsforward/dnsforward_test.go
+++ b/dnsforward/dnsforward_test.go
@@ -15,6 +15,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/AdguardTeam/AdGuardHome/dhcpd"
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/dnsproxy/upstream"
@@ -496,7 +497,7 @@ func TestBlockedCustomIP(t *testing.T) {
 	c := dnsfilter.Config{}
 
 	f := dnsfilter.New(&c, filters)
-	s := NewServer(f, nil, nil)
+	s := NewServer(DNSCreateParams{DNSFilter: f})
 	conf := ServerConfig{}
 	conf.UDPListenAddr = &net.UDPAddr{Port: 0}
 	conf.TCPListenAddr = &net.TCPAddr{Port: 0}
@@ -648,7 +649,7 @@ func TestRewrite(t *testing.T) {
 	}
 
 	f := dnsfilter.New(&c, nil)
-	s := NewServer(f, nil, nil)
+	s := NewServer(DNSCreateParams{DNSFilter: f})
 	conf := ServerConfig{}
 	conf.UDPListenAddr = &net.UDPAddr{Port: 0}
 	conf.TCPListenAddr = &net.TCPAddr{Port: 0}
@@ -705,7 +706,7 @@ func createTestServer(t *testing.T) *Server {
 	c.CacheTime = 30
 
 	f := dnsfilter.New(&c, filters)
-	s := NewServer(f, nil, nil)
+	s := NewServer(DNSCreateParams{DNSFilter: f})
 	s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
 	s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
 	s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
@@ -1012,3 +1013,39 @@ func TestMatchDNSName(t *testing.T) {
 	assert.True(t, !matchDNSName(dnsNames, ""))
 	assert.True(t, !matchDNSName(dnsNames, "*.host2"))
 }
+
+func TestPTRResponse(t *testing.T) {
+	dhcp := &dhcpd.Server{}
+	dhcp.IPpool = make(map[[4]byte]net.HardwareAddr)
+
+	c := dnsfilter.Config{}
+	f := dnsfilter.New(&c, nil)
+	s := NewServer(DNSCreateParams{DNSFilter: f, DHCPServer: dhcp})
+	s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
+	s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
+	s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
+	s.conf.FilteringConfig.ProtectionEnabled = true
+	err := s.Prepare(nil)
+	assert.True(t, err == nil)
+	assert.Nil(t, s.Start())
+
+	l := dhcpd.Lease{}
+	l.IP = net.ParseIP("127.0.0.1").To4()
+	l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
+	l.Hostname = "localhost"
+	dhcp.AddStaticLease(l)
+
+	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
+	req := createTestMessage("1.0.0.127.in-addr.arpa.")
+	req.Question[0].Qtype = dns.TypePTR
+
+	resp, err := dns.Exchange(req, addr.String())
+	assert.Nil(t, err)
+	assert.Equal(t, 1, len(resp.Answer))
+	assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
+	assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
+	ptr := resp.Answer[0].(*dns.PTR)
+	assert.Equal(t, "localhost.", ptr.Ptr)
+
+	s.Close()
+}
diff --git a/dnsforward/handle_dns.go b/dnsforward/handle_dns.go
index 87230e9b..462f3750 100644
--- a/dnsforward/handle_dns.go
+++ b/dnsforward/handle_dns.go
@@ -1,9 +1,12 @@
 package dnsforward
 
 import (
+	"strings"
 	"time"
 
+	"github.com/AdguardTeam/AdGuardHome/dhcpd"
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
+	"github.com/AdguardTeam/AdGuardHome/util"
 	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/miekg/dns"
@@ -39,6 +42,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
 	type modProcessFunc func(ctx *dnsContext) int
 	mods := []modProcessFunc{
 		processInitial,
+		processInternalIPAddrs,
 		processFilteringBeforeRequest,
 		processUpstream,
 		processDNSSECAfterResponse,
@@ -88,11 +92,82 @@ func processInitial(ctx *dnsContext) int {
 	return resultDone
 }
 
+func (s *Server) onDHCPLeaseChanged(flags int) {
+	switch flags {
+	case dhcpd.LeaseChangedAdded,
+		dhcpd.LeaseChangedAddedStatic,
+		dhcpd.LeaseChangedRemovedStatic:
+		//
+	default:
+		return
+	}
+
+	m := make(map[string]string)
+	ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
+	for _, l := range ll {
+		if len(l.Hostname) == 0 {
+			continue
+		}
+		m[l.IP.String()] = l.Hostname
+	}
+	log.Debug("DNS: added %d PTR entries from DHCP", len(m))
+	s.tablePTRLock.Lock()
+	s.tablePTR = m
+	s.tablePTRLock.Unlock()
+}
+
+// Respond to PTR requests if the target IP address is leased by our DHCP server
+func processInternalIPAddrs(ctx *dnsContext) int {
+	s := ctx.srv
+	req := ctx.proxyCtx.Req
+	if req.Question[0].Qtype != dns.TypePTR {
+		return resultDone
+	}
+
+	arpa := req.Question[0].Name
+	arpa = strings.TrimSuffix(arpa, ".")
+	arpa = strings.ToLower(arpa)
+	ip := util.DNSUnreverseAddr(arpa)
+	if ip == nil {
+		return resultDone
+	}
+
+	s.tablePTRLock.Lock()
+	if s.tablePTR == nil {
+		s.tablePTRLock.Unlock()
+		return resultDone
+	}
+	host, ok := s.tablePTR[ip.String()]
+	s.tablePTRLock.Unlock()
+	if !ok {
+		return resultDone
+	}
+
+	log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host)
+
+	resp := s.makeResponse(req)
+	ptr := &dns.PTR{}
+	ptr.Hdr = dns.RR_Header{
+		Name:   req.Question[0].Name,
+		Rrtype: dns.TypePTR,
+		Ttl:    s.conf.BlockedResponseTTL,
+		Class:  dns.ClassINET,
+	}
+	ptr.Ptr = host + "."
+	resp.Answer = append(resp.Answer, ptr)
+	ctx.proxyCtx.Res = resp
+	return resultDone
+}
+
 // Apply filtering logic
 func processFilteringBeforeRequest(ctx *dnsContext) int {
 	s := ctx.srv
 	d := ctx.proxyCtx
 
+	if d.Res != nil {
+		return resultDone // response is already set - nothing to do
+	}
+
 	s.RLock()
 	// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
 	// This could happen after proxy server has been stopped, but its workers are not yet exited.
diff --git a/home/dns.go b/home/dns.go
index db05f99a..a5547627 100644
--- a/home/dns.go
+++ b/home/dns.go
@@ -61,7 +61,13 @@ func initDNSServer() error {
 	filterConf.HTTPRegister = httpRegister
 	Context.dnsFilter = dnsfilter.New(&filterConf, nil)
 
-	Context.dnsServer = dnsforward.NewServer(Context.dnsFilter, Context.stats, Context.queryLog)
+	p := dnsforward.DNSCreateParams{
+		DNSFilter:  Context.dnsFilter,
+		Stats:      Context.stats,
+		QueryLog:   Context.queryLog,
+		DHCPServer: Context.dhcpServer,
+	}
+	Context.dnsServer = dnsforward.NewServer(p)
 	dnsConfig := generateServerConfig()
 	err = Context.dnsServer.Prepare(&dnsConfig)
 	if err != nil {
diff --git a/home/whois_test.go b/home/whois_test.go
index 31e6aba2..3ea73c53 100644
--- a/home/whois_test.go
+++ b/home/whois_test.go
@@ -9,7 +9,7 @@ import (
 
 func prepareTestDNSServer() error {
 	config.DNS.Port = 1234
-	Context.dnsServer = dnsforward.NewServer(nil, nil, nil)
+	Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
 	conf := &dnsforward.ServerConfig{}
 	conf.UpstreamDNS = []string{"8.8.8.8"}
 	return Context.dnsServer.Prepare(conf)
diff --git a/util/auto_hosts.go b/util/auto_hosts.go
index 34a979da..b12acd81 100644
--- a/util/auto_hosts.go
+++ b/util/auto_hosts.go
@@ -296,70 +296,6 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
 	return ipsCopy
 }
 
-// convert character to hex number
-func charToHex(n byte) int8 {
-	if n >= '0' && n <= '9' {
-		return int8(n) - '0'
-	} else if (n|0x20) >= 'a' && (n|0x20) <= 'f' {
-		return (int8(n) | 0x20) - 'a' + 10
-	}
-	return -1
-}
-
-// parse IPv6 reverse address
-func ipParseArpa6(s string) net.IP {
-	if len(s) != 63 {
-		return nil
-	}
-	ip6 := make(net.IP, 16)
-
-	for i := 0; i != 64; i += 4 {
-
-		// parse "0.1."
-		n := charToHex(s[i])
-		n2 := charToHex(s[i+2])
-		if s[i+1] != '.' || (i != 60 && s[i+3] != '.') ||
-			n < 0 || n2 < 0 {
-			return nil
-		}
-
-		ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f)
-	}
-	return ip6
-}
-
-// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1
-func ipReverse(ip net.IP) net.IP {
-	n := len(ip)
-	r := make(net.IP, n)
-	for i := 0; i != n; i++ {
-		r[i] = ip[n-i-1]
-	}
-	return r
-}
-
-// Convert reversed ARPA address to a normal IP address
-func dnsUnreverseAddr(s string) net.IP {
-	const arpaV4 = ".in-addr.arpa"
-	const arpaV6 = ".ip6.arpa"
-
-	if strings.HasSuffix(s, arpaV4) {
-		ip := strings.TrimSuffix(s, arpaV4)
-		ip4 := net.ParseIP(ip).To4()
-		if ip4 == nil {
-			return nil
-		}
-
-		return ipReverse(ip4)
-
-	} else if strings.HasSuffix(s, arpaV6) {
-		ip := strings.TrimSuffix(s, arpaV6)
-		return ipParseArpa6(ip)
-	}
-
-	return nil // unknown suffix
-}
-
 // ProcessReverse - process PTR request
 // Return "" if not found or an error occurred
 func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
@@ -367,7 +303,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
 		return ""
 	}
 
-	ipReal := dnsUnreverseAddr(addr)
+	ipReal := DNSUnreverseAddr(addr)
 	if ipReal == nil {
 		return "" // invalid IP in question
 	}
diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go
index 322e7e9c..ea2e43ad 100644
--- a/util/auto_hosts_test.go
+++ b/util/auto_hosts_test.go
@@ -104,11 +104,13 @@ func TestAutoHostsFSNotify(t *testing.T) {
 }
 
 func TestIP(t *testing.T) {
-	assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa").Equal(net.ParseIP("127.0.0.1").To4()))
-	assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").Equal(net.ParseIP("::abcd:1234")))
+	assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String())
+	assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
+	assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
 
-	assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa.") == nil)
-	assert.True(t, dnsUnreverseAddr(".0.0.127.in-addr.arpa") == nil)
-	assert.True(t, dnsUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa") == nil)
-	assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa") == nil)
+	assert.Nil(t, DNSUnreverseAddr("1.0.0.127.in-addr.arpa."))
+	assert.Nil(t, DNSUnreverseAddr(".0.0.127.in-addr.arpa"))
+	assert.Nil(t, DNSUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
+	assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa"))
+	assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
 }
diff --git a/util/dns.go b/util/dns.go
new file mode 100644
index 00000000..aaf51d4d
--- /dev/null
+++ b/util/dns.go
@@ -0,0 +1,70 @@
+package util
+
+import (
+	"net"
+	"strings"
+)
+
+// convert character to hex number
+func charToHex(n byte) int8 {
+	if n >= '0' && n <= '9' {
+		return int8(n) - '0'
+	} else if (n|0x20) >= 'a' && (n|0x20) <= 'f' {
+		return (int8(n) | 0x20) - 'a' + 10
+	}
+	return -1
+}
+
+// parse IPv6 reverse address
+func ipParseArpa6(s string) net.IP {
+	if len(s) != 63 {
+		return nil
+	}
+	ip6 := make(net.IP, 16)
+
+	for i := 0; i != 64; i += 4 {
+
+		// parse "0.1."
+		n := charToHex(s[i])
+		n2 := charToHex(s[i+2])
+		if s[i+1] != '.' || (i != 60 && s[i+3] != '.') ||
+			n < 0 || n2 < 0 {
+			return nil
+		}
+
+		ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f)
+	}
+	return ip6
+}
+
+// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1
+func ipReverse(ip net.IP) net.IP {
+	n := len(ip)
+	r := make(net.IP, n)
+	for i := 0; i != n; i++ {
+		r[i] = ip[n-i-1]
+	}
+	return r
+}
+
+// DNSUnreverseAddr - convert reversed ARPA address to a normal IP address
+func DNSUnreverseAddr(s string) net.IP {
+	const arpaV4 = ".in-addr.arpa"
+	const arpaV6 = ".ip6.arpa"
+
+	if strings.HasSuffix(s, arpaV4) {
+		ip := strings.TrimSuffix(s, arpaV4)
+		ip4 := net.ParseIP(ip).To4()
+		if ip4 == nil {
+			return nil
+		}
+
+		return ipReverse(ip4)
+
+	} else if strings.HasSuffix(s, arpaV6) {
+		ip := strings.TrimSuffix(s, arpaV6)
+		return ipParseArpa6(ip)
+	}
+
+	return nil // unknown suffix
+}