From 300821a7fbb354fde0f1c51ca2526ae72a88c4e1 Mon Sep 17 00:00:00 2001
From: Eugene Burkov <e.burkov@adguard.com>
Date: Thu, 27 Jul 2023 18:23:23 +0300
Subject: [PATCH] Pull request 1943: 6046 Local PTR

Merge in DNS/adguard-home from 6046-local-ptr to master

Updates #6046.

Squashed commit of the following:

commit 3e90815f29173d2f68970278bd7b1b29cc0a4465
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jul 27 18:17:41 2023 +0300

    all: log changes

commit 7639f6f785670c15911fb3ca20abeb4e2b8f8582
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jul 27 17:40:49 2023 +0300

    all: fix 0 ttl ptr
---
 CHANGELOG.md                           |  7 +++
 internal/dnsforward/dnsforward.go      | 16 +++++--
 internal/dnsforward/dnsforward_test.go | 26 ++++++++++++
 internal/dnsforward/process.go         |  2 +
 internal/rdns/rdns.go                  |  2 +
 internal/rdns/rdns_test.go             | 59 ++++++++++++++++++++------
 6 files changed, 94 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e04f023d..190d1235 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,13 @@ See also the [v0.107.36 GitHub milestone][ms-v0.107.36].
 NOTE: Add new changes BELOW THIS COMMENT.
 -->
 
+### Fixed
+
+- Client hostnames not resolving when upstream server responds with zero-TTL
+  records ([#6046]).
+
+[#6046]: https://github.com/AdguardTeam/AdGuardHome/issues/6046
+
 <!--
 NOTE: Add new changes ABOVE THIS COMMENT.
 -->
diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go
index 730e88f8..894d2ecd 100644
--- a/internal/dnsforward/dnsforward.go
+++ b/internal/dnsforward/dnsforward.go
@@ -346,19 +346,21 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
 	}
 
 	var resolver *proxy.Proxy
+	var errMsg string
 	if s.privateNets.Contains(ip.AsSlice()) {
 		if !s.conf.UsePrivateRDNS {
 			return "", 0, nil
 		}
 
 		resolver = s.localResolvers
+		errMsg = "resolving a private address: %w"
 		s.recDetector.add(*req)
 	} else {
 		resolver = s.internalProxy
+		errMsg = "resolving an address: %w"
 	}
-
 	if err = resolver.Resolve(dctx); err != nil {
-		return "", 0, err
+		return "", 0, fmt.Errorf(errMsg, err)
 	}
 
 	return hostFromPTR(dctx.Res)
@@ -377,13 +379,18 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
 
 	var ttlSec uint32
 
+	log.Debug("dnsforward: resolving ptr, received %d answers", len(resp.Answer))
 	for _, ans := range resp.Answer {
 		ptr, ok := ans.(*dns.PTR)
 		if !ok {
 			continue
 		}
 
-		if ptr.Hdr.Ttl > ttlSec {
+		// Respect zero TTL records since some DNS servers use it to
+		// locally-resolved addresses.
+		//
+		// See https://github.com/AdguardTeam/AdGuardHome/issues/6046.
+		if ptr.Hdr.Ttl >= ttlSec {
 			host = ptr.Ptr
 			ttlSec = ptr.Hdr.Ttl
 		}
@@ -465,6 +472,7 @@ func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error
 	}
 
 	ourAddrsSet := stringutil.NewSet(ourAddrs...)
+	log.Debug("dnsforward: filtering out %s", ourAddrsSet.String())
 
 	// TODO(e.burkov): The approach of subtracting sets of strings is not
 	// really applicable here since in case of listening on all network
@@ -501,7 +509,7 @@ func (s *Server) setupLocalResolvers() (err error) {
 		PreferIPv6: s.conf.BootstrapPreferIPv6,
 	})
 	if err != nil {
-		return fmt.Errorf("parsing private upstreams: %w", err)
+		return fmt.Errorf("preparing private upstreams: %w", err)
 	}
 
 	s.localResolvers = &proxy.Proxy{
diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go
index 775a97b5..a5e25681 100644
--- a/internal/dnsforward/dnsforward_test.go
+++ b/internal/dnsforward/dnsforward_test.go
@@ -1374,6 +1374,24 @@ func TestServer_Exchange(t *testing.T) {
 	refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
 		return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil
 	})
+	zeroTTLUps := &aghtest.UpstreamMock{
+		OnAddress: func() (addr string) { return "zero.ttl.example" },
+		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
+			resp = new(dns.Msg).SetReply(req)
+			hdr := dns.RR_Header{
+				Name:   req.Question[0].Name,
+				Rrtype: dns.TypePTR,
+				Class:  dns.ClassINET,
+				Ttl:    0,
+			}
+			resp.Answer = []dns.RR{&dns.PTR{
+				Hdr: hdr,
+				Ptr: localDomainHost,
+			}}
+
+			return resp, nil
+		},
+	}
 
 	srv := &Server{
 		recDetector: newRecursionDetector(0, 1),
@@ -1445,6 +1463,13 @@ func TestServer_Exchange(t *testing.T) {
 		locUpstream: nil,
 		req:         twosIP,
 		wantTTL:     defaultTTL * 2,
+	}, {
+		name:        "zero_ttl",
+		want:        localDomainHost,
+		wantErr:     nil,
+		locUpstream: zeroTTLUps,
+		req:         localIP,
+		wantTTL:     0,
 	}}
 
 	for _, tc := range testCases {
@@ -1468,6 +1493,7 @@ func TestServer_Exchange(t *testing.T) {
 
 	t.Run("resolving_disabled", func(t *testing.T) {
 		srv.conf.UsePrivateRDNS = false
+		t.Cleanup(func() { srv.conf.UsePrivateRDNS = true })
 
 		host, _, eerr := srv.Exchange(localIP)
 
diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go
index 60feb968..13a8a2eb 100644
--- a/internal/dnsforward/process.go
+++ b/internal/dnsforward/process.go
@@ -719,6 +719,8 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
 	if s.conf.UsePrivateRDNS {
 		s.recDetector.add(*pctx.Req)
 		if err := s.localResolvers.Resolve(pctx); err != nil {
+			log.Debug("dnsforward: resolving private address: %s", err)
+
 			// Generate the server failure if the private upstream configuration
 			// is empty.
 			//
diff --git a/internal/rdns/rdns.go b/internal/rdns/rdns.go
index b33e212c..93898b3e 100644
--- a/internal/rdns/rdns.go
+++ b/internal/rdns/rdns.go
@@ -101,6 +101,8 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
 		log.Debug("rdns: cache: adding item %q: %s", ip, err)
 	}
 
+	// TODO(e.burkov):  The name doesn't change if it's neither stored in cache
+	// nor resolved successfully.  Is it correct?
 	return host, fromCache == "" || host != fromCache
 }
 
diff --git a/internal/rdns/rdns_test.go b/internal/rdns/rdns_test.go
index 61130ec5..0db13728 100644
--- a/internal/rdns/rdns_test.go
+++ b/internal/rdns/rdns_test.go
@@ -25,11 +25,6 @@ func TestDefault_Process(t *testing.T) {
 	localRevAddr1, err := netutil.IPToReversedAddr(localIP.AsSlice())
 	require.NoError(t, err)
 
-	config := &rdns.Config{
-		CacheSize: 100,
-		CacheTTL:  time.Hour,
-	}
-
 	testCases := []struct {
 		name string
 		addr netip.Addr
@@ -60,21 +55,21 @@ func TestDefault_Process(t *testing.T) {
 
 				switch ip {
 				case ip1:
-					return revAddr1, 0, nil
+					return revAddr1, time.Hour, nil
 				case ip2:
-					return revAddr2, 0, nil
+					return revAddr2, time.Hour, nil
 				case localIP:
-					return localRevAddr1, 0, nil
+					return localRevAddr1, time.Hour, nil
 				default:
-					return "", 0, nil
+					return "", time.Hour, nil
 				}
 			}
-			exchanger := &aghtest.Exchanger{
-				OnExchange: onExchange,
-			}
 
-			config.Exchanger = exchanger
-			r := rdns.New(config)
+			r := rdns.New(&rdns.Config{
+				CacheSize: 100,
+				CacheTTL:  time.Hour,
+				Exchanger: &aghtest.Exchanger{OnExchange: onExchange},
+			})
 
 			got, changed := r.Process(tc.addr)
 			require.True(t, changed)
@@ -90,4 +85,40 @@ func TestDefault_Process(t *testing.T) {
 			assert.Equal(t, 1, hit)
 		})
 	}
+
+	t.Run("zero_ttl", func(t *testing.T) {
+		const cacheTTL = time.Second / 2
+
+		zeroTTLExchanger := &aghtest.Exchanger{
+			OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) {
+				return revAddr1, 0, nil
+			},
+		}
+
+		r := rdns.New(&rdns.Config{
+			CacheSize: 1,
+			CacheTTL:  cacheTTL,
+			Exchanger: zeroTTLExchanger,
+		})
+
+		got, changed := r.Process(ip1)
+		require.True(t, changed)
+		assert.Equal(t, revAddr1, got)
+
+		zeroTTLExchanger.OnExchange = func(ip netip.Addr) (host string, ttl time.Duration, err error) {
+			return revAddr2, time.Hour, nil
+		}
+
+		require.EventuallyWithT(t, func(t *assert.CollectT) {
+			got, changed = r.Process(ip1)
+			assert.True(t, changed)
+			assert.Equal(t, revAddr2, got)
+		}, 2*cacheTTL, time.Millisecond*100)
+
+		assert.Never(t, func() (changed bool) {
+			_, changed = r.Process(ip1)
+
+			return changed
+		}, 2*cacheTTL, time.Millisecond*100)
+	})
 }