From 8b8ae8ffadd34ebb69a28a958c4ccc1feb73bb69 Mon Sep 17 00:00:00 2001
From: Ainar Garipov <a.garipov@adguard.com>
Date: Fri, 8 Sep 2023 17:55:13 +0300
Subject: [PATCH] Pull request 2007: 6183-orig-resp

Closes #6183.

Squashed commit of the following:

commit a99b935d7a152f2cf2d003057cfb8e3c7c3579c5
Merge: 3534f663f 36517fc21
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Sep 8 17:46:51 2023 +0300

    Merge branch 'master' into 6183-orig-resp

commit 3534f663ff4aaacc4a1044b018802bd23cd8f7ec
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Sep 8 17:00:54 2023 +0300

    dnsforward: fix orig resp
---
 CHANGELOG.md                       |  2 ++
 internal/dnsforward/filter.go      | 47 ++++++++++++++++--------------
 internal/dnsforward/filter_test.go | 22 +++++++++-----
 internal/dnsforward/process.go     | 21 +++++--------
 4 files changed, 50 insertions(+), 42 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 49fc2311..41812ad3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,12 +25,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
 
 ### Fixed
 
+- Incorrect original answer when a response is filtered ([#6183]).
 - Comments in the *Fallback DNS servers* field in the UI ([#6182]).
 - Empty or default Safe Browsing and Parental Control settings ([#6181]).
 - Various UI issues.
 
 [#6181]: https://github.com/AdguardTeam/AdGuardHome/issues/6181
 [#6182]: https://github.com/AdguardTeam/AdGuardHome/issues/6182
+[#6183]: https://github.com/AdguardTeam/AdGuardHome/issues/6183
 
 <!--
 NOTE: Add new changes ABOVE THIS COMMENT.
diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go
index 6f551e59..d80f022e 100644
--- a/internal/dnsforward/filter.go
+++ b/internal/dnsforward/filter.go
@@ -11,6 +11,7 @@ import (
 	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/AdguardTeam/golibs/netutil"
+	"github.com/AdguardTeam/urlfilter/rules"
 	"github.com/miekg/dns"
 	"golang.org/x/exp/slices"
 )
@@ -140,15 +141,15 @@ func (s *Server) filterRewritten(
 
 // checkHostRules checks the host against filters.  It is safe for concurrent
 // use.
-func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
-	r *filtering.Result,
-	err error,
-) {
+func (s *Server) checkHostRules(
+	host string,
+	rrtype rules.RRType,
+	setts *filtering.Settings,
+) (r *filtering.Result, err error) {
 	s.serverLock.RLock()
 	defer s.serverLock.RUnlock()
 
-	var res filtering.Result
-	res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
+	res, err := s.dnsFilter.CheckHostRules(host, rrtype, setts)
 	if err != nil {
 		return nil, err
 	}
@@ -156,20 +157,21 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
 	return &res, err
 }
 
-// filterDNSResponse checks each resource record of the response's answer
-// section from pctx and returns a non-nil res if at least one of canonical
-// names or IP addresses in it matches the filtering rules.
-func (s *Server) filterDNSResponse(
-	pctx *proxy.DNSContext,
-	setts *filtering.Settings,
-) (res *filtering.Result, err error) {
+// filterDNSResponse checks each resource record of answer section of
+// dctx.proxyCtx.Res.  It sets dctx.result and dctx.origResp if at least one of
+// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
+// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
+func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
+	setts := dctx.setts
 	if !setts.FilteringEnabled {
-		return nil, nil
+		return nil
 	}
 
-	for _, a := range pctx.Res.Answer {
+	var res *filtering.Result
+	pctx := dctx.proxyCtx
+	for i, a := range pctx.Res.Answer {
 		host := ""
-		var rrtype uint16
+		var rrtype rules.RRType
 		switch a := a.(type) {
 		case *dns.CNAME:
 			host = strings.TrimSuffix(a.Target, ".")
@@ -195,18 +197,19 @@ func (s *Server) filterDNSResponse(
 		log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
 
 		if err != nil {
-			return nil, err
-		} else if res == nil {
-			continue
-		} else if res.IsFiltered {
+			return fmt.Errorf("filtering answer at index %d: %w", i, err)
+		} else if res != nil && res.IsFiltered {
+			dctx.result = res
+			dctx.origResp = pctx.Res
 			pctx.Res = s.genDNSFilterMessage(pctx, res)
+
 			log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)
 
-			return res, nil
+			break
 		}
 	}
 
-	return nil, nil
+	return nil
 }
 
 // removeIPv6Hints deletes IPv6 hints from RR values.
diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go
index 88a316e8..fe64cdf0 100644
--- a/internal/dnsforward/filter_test.go
+++ b/internal/dnsforward/filter_test.go
@@ -328,26 +328,34 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
 				Addr:  &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
 			}
 
-			res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
-				ProtectionEnabled: true,
-				FilteringEnabled:  true,
-			})
-			require.NoError(t, rErr)
+			dctx := &dnsContext{
+				proxyCtx: pctx,
+				setts: &filtering.Settings{
+					ProtectionEnabled: true,
+					FilteringEnabled:  true,
+				},
+			}
 
+			fltErr := s.filterDNSResponse(dctx)
+			require.NoError(t, fltErr)
+
+			res := dctx.result
 			if tc.wantRule == "" {
 				assert.Nil(t, res)
 
 				return
 			}
 
-			want := &filtering.Result{
+			wantResult := &filtering.Result{
 				IsFiltered: true,
 				Reason:     filtering.FilteredBlockList,
 				Rules: []*filtering.ResultRule{{
 					Text: tc.wantRule,
 				}},
 			}
-			assert.Equal(t, want, res)
+
+			assert.Equal(t, wantResult, res)
+			assert.Equal(t, resp, dctx.origResp)
 		})
 	}
 }
diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go
index 4780c856..0b572d8b 100644
--- a/internal/dnsforward/process.go
+++ b/internal/dnsforward/process.go
@@ -671,11 +671,11 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
 }
 
 // Apply filtering logic
-func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
+func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
 	log.Debug("dnsforward: started processing filtering before req")
 	defer log.Debug("dnsforward: finished processing filtering before req")
 
-	if ctx.proxyCtx.Res != nil {
+	if dctx.proxyCtx.Res != nil {
 		// Go on since the response is already set.
 		return resultCodeSuccess
 	}
@@ -684,8 +684,8 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
 	defer s.serverLock.RUnlock()
 
 	var err error
-	if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
-		ctx.err = err
+	if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
+		dctx.err = err
 
 		return resultCodeError
 	}
@@ -857,7 +857,6 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
 	log.Debug("dnsforward: started processing filtering after resp")
 	defer log.Debug("dnsforward: finished processing filtering after resp")
 
-	pctx := dctx.proxyCtx
 	switch res := dctx.result; res.Reason {
 	case filtering.NotFilteredAllowList:
 		return resultCodeSuccess
@@ -871,6 +870,7 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
 			return resultCodeSuccess
 		}
 
+		pctx := dctx.proxyCtx
 		pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
 		if len(pctx.Res.Answer) > 0 {
 			rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
@@ -880,13 +880,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
 
 		return resultCodeSuccess
 	default:
-		return s.filterAfterResponse(dctx, pctx)
+		return s.filterAfterResponse(dctx)
 	}
 }
 
 // filterAfterResponse returns the result of filtering the response that wasn't
 // explicitly allowed or rewritten.
-func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (res resultCode) {
+func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
 	// Check the response only if it's from an upstream.  Don't check the
 	// response if the protection is disabled since dnsrewrite rules aren't
 	// applied to it anyway.
@@ -894,17 +894,12 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
 		return resultCodeSuccess
 	}
 
-	result, err := s.filterDNSResponse(pctx, dctx.setts)
+	err := s.filterDNSResponse(dctx)
 	if err != nil {
 		dctx.err = err
 
 		return resultCodeError
 	}
 
-	if result != nil {
-		dctx.result = result
-		dctx.origResp = pctx.Res
-	}
-
 	return resultCodeSuccess
 }