From 496cbba94ec8c1684001f8ed0245b51a73d5bffe Mon Sep 17 00:00:00 2001
From: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Thu, 23 Jan 2025 19:44:26 +0300
Subject: [PATCH] all: imp code

---
 go.mod                       |  2 +-
 go.sum                       |  4 ++--
 internal/dnsforward/stats.go | 41 +++++++-----------------------------
 internal/stats/stats_test.go | 13 ++++++++----
 internal/stats/unit.go       | 21 ++++++++++--------
 5 files changed, 32 insertions(+), 49 deletions(-)

diff --git a/go.mod b/go.mod
index d55a5452..aebc7d3a 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.23.5
 
 require (
 	// TODO!!
-	github.com/AdguardTeam/dnsproxy v0.74.2-0.20250116174805-966cabfa8953
+	github.com/AdguardTeam/dnsproxy v0.74.2-0.20250123124619-13a82417e9e2
 	github.com/AdguardTeam/golibs v0.31.0
 	github.com/AdguardTeam/urlfilter v0.20.0
 	github.com/NYTimes/gziphandler v1.1.1
diff --git a/go.sum b/go.sum
index 18157b2a..ee53df43 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-github.com/AdguardTeam/dnsproxy v0.74.2-0.20250116174805-966cabfa8953 h1:oWKRUtLrKqUO0g3Vh/uEvuxH4wpqh5mV2r7fwqmABF8=
-github.com/AdguardTeam/dnsproxy v0.74.2-0.20250116174805-966cabfa8953/go.mod h1:Oqw+k7LyjDObfYzXYCkpgtirbzbUrmotr92jrb3N09I=
+github.com/AdguardTeam/dnsproxy v0.74.2-0.20250123124619-13a82417e9e2 h1:YhG4TGJYPFZbio1Pwo3uxONNUqRJL2dfzVIqtbBl6MQ=
+github.com/AdguardTeam/dnsproxy v0.74.2-0.20250123124619-13a82417e9e2/go.mod h1:Oqw+k7LyjDObfYzXYCkpgtirbzbUrmotr92jrb3N09I=
 github.com/AdguardTeam/golibs v0.31.0 h1:Z0oPfLTLw6iZmpE58dePy2Bel0MaX+lnDwtFEE5EmIo=
 github.com/AdguardTeam/golibs v0.31.0/go.mod h1:wIkZ9o2UnppeW6/YD7yJB71dYbMhiuC1Fh/I2ElW7GQ=
 github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go
index 622421c4..50818b40 100644
--- a/internal/dnsforward/stats.go
+++ b/internal/dnsforward/stats.go
@@ -139,16 +139,20 @@ func (s *Server) logQuery(dctx *dnsContext, ip net.IP, processingTime time.Durat
 func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) {
 	pctx := dctx.proxyCtx
 
+	var upstreamStats []*proxy.UpstreamStatistics
+	qs := pctx.QueryStatistics()
+	if qs != nil {
+		upstreamStats = append(upstreamStats, qs.Main()...)
+		upstreamStats = append(upstreamStats, qs.Fallback()...)
+	}
+
 	e := &stats.Entry{
+		UpstreamStats:  upstreamStats,
 		Domain:         aghnet.NormalizeDomain(pctx.Req.Question[0].Name),
 		Result:         stats.RNotFiltered,
 		ProcessingTime: processingTime,
 	}
 
-	if pctx.Upstream != nil {
-		e.Upstream, e.UpstreamTime = upstreamDur(pctx)
-	}
-
 	if clientID := dctx.clientID; clientID != "" {
 		e.Client = clientID
 	} else {
@@ -171,32 +175,3 @@ func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime t
 
 	s.stats.Update(e)
 }
-
-// upstreamDur returns the upstream DNS server address and the DNS lookup
-// duration.  If the upstream address is empty, it means the request was served
-// from the cache.
-func upstreamDur(pctx *proxy.DNSContext) (upstream string, dur time.Duration) {
-	if pctx.Upstream == nil {
-		return "", 0
-	}
-
-	qs := pctx.QueryStatistics()
-	if qs == nil {
-		return "", 0
-	}
-
-	addr := pctx.Upstream.Address()
-	for _, u := range qs.Main() {
-		if u.Address == addr {
-			return u.Address, u.QueryDuration
-		}
-	}
-
-	for _, u := range qs.Fallback() {
-		if u.Address == addr {
-			return u.Address, u.QueryDuration
-		}
-	}
-
-	return "", 0
-}
diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go
index dbf857d6..06aa36f3 100644
--- a/internal/stats/stats_test.go
+++ b/internal/stats/stats_test.go
@@ -13,6 +13,7 @@ import (
 
 	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
 	"github.com/AdguardTeam/AdGuardHome/internal/stats"
+	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/golibs/logutil/slogutil"
 	"github.com/AdguardTeam/golibs/netutil"
 	"github.com/AdguardTeam/golibs/testutil"
@@ -78,15 +79,19 @@ func TestStats(t *testing.T) {
 			Client:         cliIPStr,
 			Result:         stats.RFiltered,
 			ProcessingTime: time.Microsecond * 123456,
-			Upstream:       respUpstream,
-			UpstreamTime:   time.Microsecond * 222222,
+			UpstreamStats: []*proxy.UpstreamStatistics{{
+				Address:       respUpstream,
+				QueryDuration: time.Microsecond * 222222,
+			}},
 		}, {
 			Domain:         reqDomain,
 			Client:         cliIPStr,
 			Result:         stats.RNotFiltered,
 			ProcessingTime: time.Microsecond * 123456,
-			Upstream:       respUpstream,
-			UpstreamTime:   time.Microsecond * 222222,
+			UpstreamStats: []*proxy.UpstreamStatistics{{
+				Address:       respUpstream,
+				QueryDuration: time.Microsecond * 222222,
+			}},
 		}}
 
 		wantData := &stats.StatsResp{
diff --git a/internal/stats/unit.go b/internal/stats/unit.go
index cef7fca2..5eef6f3c 100644
--- a/internal/stats/unit.go
+++ b/internal/stats/unit.go
@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
+	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/golibs/errors"
 	"github.com/AdguardTeam/golibs/logutil/slogutil"
 	"go.etcd.io/bbolt"
@@ -62,8 +63,9 @@ type Entry struct {
 	// Domain is the domain name requested.
 	Domain string
 
-	// Upstream is the upstream DNS server.
-	Upstream string
+	// UpstreamStats contains the DNS query statistics for both the upstream and
+	// fallback DNS servers.
+	UpstreamStats []*proxy.UpstreamStatistics
 
 	// Result is the result of processing the request.
 	Result Result
@@ -71,9 +73,6 @@ type Entry struct {
 	// ProcessingTime is the duration of the request processing from the start
 	// of the request including timeouts.
 	ProcessingTime time.Duration
-
-	// UpstreamTime is the duration of the successful request to the upstream.
-	UpstreamTime time.Duration
 }
 
 // validate returns an error if entry is not valid.
@@ -329,10 +328,14 @@ func (u *unit) add(e *Entry) {
 	u.timeSum += pt
 	u.nTotal++
 
-	if e.Upstream != "" {
-		u.upstreamsResponses[e.Upstream]++
-		ut := uint64(e.UpstreamTime.Microseconds())
-		u.upstreamsTimeSum[e.Upstream] += ut
+	for _, s := range e.UpstreamStats {
+		if s.IsCached || s.Error != nil {
+			continue
+		}
+
+		addr := s.Address
+		u.upstreamsResponses[addr]++
+		u.upstreamsTimeSum[addr] += uint64(s.QueryDuration.Microseconds())
 	}
 }