diff --git a/go.mod b/go.mod index b9c6330e..1b9f67d7 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.21.8 require ( - github.com/AdguardTeam/dnsproxy v0.65.2 + github.com/AdguardTeam/dnsproxy v0.66.0 github.com/AdguardTeam/golibs v0.20.1 github.com/AdguardTeam/urlfilter v0.18.0 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index bc1679f5..557c6cbb 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.65.2 h1:D+BMw0Vu2lbQrYpoPctG2Xr+24KdfhgkzZb6QgPZheM= -github.com/AdguardTeam/dnsproxy v0.65.2/go.mod h1:8NQTTNZY+qR9O1Fzgz3WQv30knfSgms68SRlzSnX74A= +github.com/AdguardTeam/dnsproxy v0.66.0 h1:RyUbyDxRSXBFjVG1l2/4HV3I98DtfIgpnZkgXkgHKnc= +github.com/AdguardTeam/dnsproxy v0.66.0/go.mod h1:ZThEXbMUlP1RxfwtNW30ItPAHE6OF4YFygK8qjU/cvY= github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY= github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 98de9b05..1d9067c5 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -9,8 +9,13 @@ import ( "net/netip" "net/url" "testing" + "time" + "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" "github.com/stretchr/testify/require" ) @@ -71,3 +76,49 @@ func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) { return srv.Client(), u } + +// testTimeout is a timeout for tests. +// +// TODO(e.burkov): Move into agdctest. +const testTimeout = 1 * time.Second + +// StartLocalhostUpstream is a test helper that starts a DNS server on +// localhost. +func StartLocalhostUpstream(t *testing.T, h dns.Handler) (addr *url.URL) { + t.Helper() + + startCh := make(chan netip.AddrPort) + defer close(startCh) + errCh := make(chan error) + + srv := &dns.Server{ + Addr: "127.0.0.1:0", + Net: string(proxy.ProtoTCP), + Handler: h, + ReadTimeout: testTimeout, + WriteTimeout: testTimeout, + } + srv.NotifyStartedFunc = func() { + addrPort := srv.Listener.Addr() + startCh <- netutil.NetAddrToAddrPort(addrPort) + } + + go func() { errCh <- srv.ListenAndServe() }() + + select { + case addrPort := <-startCh: + addr = &url.URL{ + Scheme: string(proxy.ProtoTCP), + Host: addrPort.String(), + } + + testutil.CleanupAndRequireSuccess(t, func() (err error) { return <-errCh }) + testutil.CleanupAndRequireSuccess(t, srv.Shutdown) + case err := <-errCh: + require.NoError(t, err) + case <-time.After(testTimeout): + require.FailNow(t, "timeout exceeded") + } + + return addr +} diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 30211677..cc75cdf7 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -357,10 +357,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { conf.DNSCryptResolverCert = c.ResolverCert } - if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 { - return nil, errors.Error("no default upstream servers configured") - } - conf, err = prepareCacheConfig(conf, srvConf.CacheSize, srvConf.CacheMinTTL, diff --git a/internal/dnsforward/dns64_test.go b/internal/dnsforward/dns64_test.go index ad89098c..49e1e4ce 100644 --- a/internal/dnsforward/dns64_test.go +++ b/internal/dnsforward/dns64_test.go @@ -8,7 +8,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -101,21 +100,6 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { type answerMap = map[uint16][sectionsNum][]dns.RR pt := testutil.PanicT{} - newUps := func(answers answerMap) (u upstream.Upstream) { - return aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - q := req.Question[0] - require.Contains(pt, answers, q.Qtype) - - answer := answers[q.Qtype] - - resp = (&dns.Msg{}).SetReply(req) - resp.Answer = answer[sectionAnswer] - resp.Ns = answer[sectionAuthority] - resp.Extra = answer[sectionAdditional] - - return resp, nil - }) - } testCases := []struct { name string @@ -265,13 +249,16 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { }} localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain) - localUps := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - require.Equal(pt, req.Question[0].Name, ptr64Domain) - resp = (&dns.Msg{}).SetReply(req) - resp.Answer = []dns.RR{localRR} + localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + require.Len(pt, m.Question, 1) + require.Equal(pt, m.Question[0].Name, ptr64Domain) + resp := (&dns.Msg{ + Answer: []dns.RR{localRR}, + }).SetReply(m) - return resp, nil + require.NoError(t, w.WriteMsg(resp)) }) + localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() client := &dns.Client{ Net: "tcp", @@ -279,25 +266,44 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { } for _, tc := range testCases { - // TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be reused - // right after stop, due to a data race in [proxy.Proxy.Init] method - // when setting an OOB size. As a temporary workaround, recreate the - // whole server for each test case. - s := createTestServer(t, &filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, - }, ServerConfig{ - UDPListenAddrs: []*net.UDPAddr{{}}, - TCPListenAddrs: []*net.TCPAddr{{}}, - UseDNS64: true, - Config: Config{ - UpstreamMode: UpstreamModeLoadBalance, - EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, - }, - ServePlainDNS: true, - }, localUps) - + tc := tc t.Run(tc.name, func(t *testing.T) { - s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)} + upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + q := req.Question[0] + require.Contains(pt, tc.upsAns, q.Qtype) + + answer := tc.upsAns[q.Qtype] + + resp := (&dns.Msg{ + Answer: answer[sectionAnswer], + Ns: answer[sectionAuthority], + Extra: answer[sectionAdditional], + }).SetReply(req) + + require.NoError(pt, w.WriteMsg(resp)) + }) + upsAddr := aghtest.StartLocalhostUpstream(t, upsHdlr).String() + + // TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be + // reused right after stop, due to a data race in [proxy.Proxy.Init] + // method when setting an OOB size. As a temporary workaround, + // recreate the whole server for each test case. + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + UseDNS64: true, + Config: Config{ + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + UpstreamDNS: []string{upsAddr}, + }, + UsePrivateRDNS: true, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, + }) + startDeferStop(t, s) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 6ad06333..0380d004 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -464,7 +464,8 @@ func (s *Server) Start() error { // startLocked starts the DNS server without locking. s.serverLock is expected // to be locked. func (s *Server) startLocked() error { - err := s.dnsProxy.Start() + // TODO(e.burkov): Use context properly. + err := s.dnsProxy.Start(context.Background()) if err == nil { s.isRunning = true } @@ -518,34 +519,30 @@ func (s *Server) prepareLocalResolvers( } // setupLocalResolvers initializes and sets the resolvers for local addresses. -// It assumes s.serverLock is locked or s not running. -func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) { - uc, err := s.prepareLocalResolvers(boot) +// It assumes s.serverLock is locked or s not running. It returns the upstream +// configuration used for private PTR resolving, or nil if it's disabled. Note, +// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. +func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) { + if !s.conf.UsePrivateRDNS { + // It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. + return nil, nil + } + + uc, err = s.prepareLocalResolvers(boot) if err != nil { // Don't wrap the error because it's informative enough as is. - return err + return nil, err } - s.localResolvers = &proxy.Proxy{ - Config: proxy.Config{ - UpstreamConfig: uc, - }, - } - - err = s.localResolvers.Init() + s.localResolvers, err = proxy.New(&proxy.Config{ + UpstreamConfig: uc, + }) if err != nil { - return fmt.Errorf("initializing proxy: %w", err) + return nil, fmt.Errorf("creating local resolvers: %w", err) } // TODO(e.burkov): Should we also consider the DNS64 usage? - if s.conf.UsePrivateRDNS && - // Only set the upstream config if there are any upstreams. It's safe - // to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. - len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 { - s.dnsProxy.PrivateRDNSUpstreamConfig = uc - } - - return nil + return uc, nil } // Prepare initializes parameters of s using data from conf. conf must not be @@ -586,21 +583,22 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { return fmt.Errorf("preparing access: %w", err) } - // Set the proxy here because [setupLocalResolvers] sets its values. - // // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. - s.dnsProxy = &proxy.Proxy{Config: *proxyConfig} - - err = s.setupLocalResolvers(boot) + proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot) if err != nil { return fmt.Errorf("setting up resolvers: %w", err) } - err = s.setupFallbackDNS() + proxyConfig.Fallbacks, err = s.setupFallbackDNS() if err != nil { return fmt.Errorf("setting up fallback dns servers: %w", err) } + s.dnsProxy, err = proxy.New(proxyConfig) + if err != nil { + return fmt.Errorf("creating proxy: %w", err) + } + s.recDetector.clear() s.setupAddrProc() @@ -643,26 +641,25 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { } // setupFallbackDNS initializes the fallback DNS servers. -func (s *Server) setupFallbackDNS() (err error) { +func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) { fallbacks := s.conf.FallbackDNS fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty) if len(fallbacks) == 0 { - return nil + return nil, nil } - uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{ + uc, err = proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{ // TODO(s.chzhen): Investigate if other options are needed. Timeout: s.conf.UpstreamTimeout, PreferIPv6: s.conf.BootstrapPreferIPv6, + // TODO(e.burkov): Use bootstrap. }) if err != nil { // Do not wrap the error because it's informative enough as is. - return err + return nil, err } - s.dnsProxy.Fallbacks = uc - - return nil + return uc, nil } // setupAddrProc initializes the address processor. It assumes s.serverLock is @@ -730,19 +727,9 @@ func (s *Server) prepareInternalProxy() (err error) { return fmt.Errorf("invalid upstream mode: %w", err) } - // TODO(a.garipov): Make a proper constructor for proxy.Proxy. - p := &proxy.Proxy{ - Config: *conf, - } + s.internalProxy, err = proxy.New(conf) - err = p.Init() - if err != nil { - return err - } - - s.internalProxy = p - - return nil + return err } // Stop stops the DNS server. @@ -761,14 +748,17 @@ func (s *Server) stopLocked() (err error) { // [upstream.Upstream] implementations. if s.dnsProxy != nil { - err = s.dnsProxy.Stop() + // TODO(e.burkov): Use context properly. + err = s.dnsProxy.Shutdown(context.Background()) if err != nil { log.Error("dnsforward: closing primary resolvers: %s", err) } } logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s") - logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s") + if s.localResolvers != nil { + logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s") + } for _, b := range s.bootResolvers { logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 0cbb21cb..b490b38d 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -5,9 +5,11 @@ import ( "crypto/ecdsa" "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/hex" "encoding/pem" "fmt" "math/big" @@ -63,8 +65,7 @@ func startDeferStop(t *testing.T, s *Server) { t.Helper() err := s.Start() - require.NoErrorf(t, err, "failed to start server: %s", err) - + require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, s.Stop) } @@ -72,7 +73,6 @@ func createTestServer( t *testing.T, filterConf *filtering.Config, forwardConf ServerConfig, - localUps upstream.Upstream, ) (s *Server) { t.Helper() @@ -82,7 +82,8 @@ func createTestServer( @@||whitelist.example.org^ ||127.0.0.255` filters := []filtering.Filter{{ - ID: 0, Data: []byte(rules), + ID: 0, + Data: []byte(rules), }} f, err := filtering.New(filterConf, filters) @@ -105,19 +106,6 @@ func createTestServer( err = s.Prepare(&forwardConf) require.NoError(t, err) - s.serverLock.Lock() - defer s.serverLock.Unlock() - - // TODO(e.burkov): Try to move it higher. - if localUps != nil { - ups := []upstream.Upstream{localUps} - s.localResolvers.UpstreamConfig.Upstreams = ups - s.conf.UsePrivateRDNS = true - s.dnsProxy.PrivateRDNSUpstreamConfig = &proxy.UpstreamConfig{ - Upstreams: ups, - } - } - return s } @@ -181,7 +169,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem s.conf.TLSConfig = tlsConf @@ -310,7 +298,7 @@ func TestServer(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} startDeferStop(t, s) @@ -410,7 +398,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} startDeferStop(t, s) @@ -490,7 +478,7 @@ func TestServerRace(t *testing.T) { ConfigModified: func() {}, ServePlainDNS: true, } - s := createTestServer(t, filterConf, forwardConf, nil) + s := createTestServer(t, filterConf, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} startDeferStop(t, s) @@ -545,7 +533,7 @@ func TestSafeSearch(t *testing.T) { }, ServePlainDNS: true, } - s := createTestServer(t, filterConf, forwardConf, nil) + s := createTestServer(t, filterConf, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() @@ -628,7 +616,7 @@ func TestInvalidRequest(t *testing.T) { }, }, ServePlainDNS: true, - }, nil) + }) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() @@ -662,7 +650,7 @@ func TestBlockedRequest(t *testing.T) { s := createTestServer(t, &filtering.Config{ ProtectionEnabled: true, BlockingMode: filtering.BlockingModeDefault, - }, forwardConf, nil) + }, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -698,7 +686,7 @@ func TestServerCustomClientUpstream(t *testing.T) { } s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, - }, forwardConf, nil) + }, forwardConf) ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { atomic.AddUint32(&upsCalledCounter, 1) @@ -773,7 +761,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { }, }, ServePlainDNS: true, - }, nil) + }) testUpstm := &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, @@ -811,7 +799,7 @@ func TestBlockCNAME(t *testing.T) { s := createTestServer(t, &filtering.Config{ ProtectionEnabled: true, BlockingMode: filtering.BlockingModeDefault, - }, forwardConf, nil) + }, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.Upstream{ CName: testCNAMEs, @@ -886,7 +874,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { } s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, - }, forwardConf, nil) + }, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.Upstream{ CName: testCNAMEs, @@ -933,7 +921,7 @@ func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t, &filtering.Config{ ProtectionEnabled: true, BlockingMode: filtering.BlockingModeNullIP, - }, forwardConf, nil) + }, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1054,7 +1042,7 @@ func TestBlockedByHosts(t *testing.T) { s := createTestServer(t, &filtering.Config{ ProtectionEnabled: true, BlockingMode: filtering.BlockingModeDefault, - }, forwardConf, nil) + }, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1102,7 +1090,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { }, ServePlainDNS: true, } - s := createTestServer(t, filterConf, forwardConf, nil) + s := createTestServer(t, filterConf, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1482,6 +1470,8 @@ func TestServer_Exchange(t *testing.T) { onesIP = netip.MustParseAddr("1.1.1.1") twosIP = netip.MustParseAddr("2.2.2.2") localIP = netip.MustParseAddr("192.168.1.1") + + pt = testutil.PanicT{} ) onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) @@ -1490,72 +1480,73 @@ func TestServer_Exchange(t *testing.T) { twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice()) require.NoError(t, err) - extUpstream := &aghtest.UpstreamMock{ - OnAddress: func() (addr string) { return "external.upstream.example" }, - OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost), - doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil - }, - } + extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := aghalg.Coalesce( + aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)), + doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + require.NoError(pt, w.WriteMsg(resp)) + }) + upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String() revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice()) require.NoError(t, err) - locUpstream := &aghtest.UpstreamMock{ - OnAddress: func() (addr string) { return "local.upstream.example" }, - OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil - }, - } + locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := aghalg.Coalesce( + aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) - errUpstream := aghtest.NewErrorUpstream() - nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true) - refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil + require.NoError(pt, w.WriteMsg(resp)) }) - zeroTTLUps := &aghtest.UpstreamMock{ - OnAddress: func() (addr string) { return "zero.ttl.example" }, - OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - resp = new(dns.Msg).SetReply(req) - hdr := dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - Ttl: 0, - } - resp.Answer = []dns.RR{&dns.PTR{ - Hdr: hdr, - Ptr: localDomainHost, - }} - return resp, nil - }, - } + errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure))) + }) - srv := &Server{ - recDetector: newRecursionDetector(0, 1), - internalProxy: &proxy.Proxy{ - Config: proxy.Config{ - UpstreamConfig: &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{extUpstream}, + nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + hash := sha256.Sum256([]byte("some-host")) + resp := (&dns.Msg{ + Answer: []dns.RR{&dns.TXT{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 60, }, - }, - }, - } - srv.conf.UsePrivateRDNS = true - srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) - require.NoError(t, srv.internalProxy.Init()) + Txt: []string{hex.EncodeToString(hash[:])}, + }}, + }).SetReply(req) + + require.NoError(pt, w.WriteMsg(resp)) + }) + refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused))) + }) + + zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := (&dns.Msg{ + Answer: []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 0, + }, + Ptr: dns.Fqdn(localDomainHost), + }}, + }).SetReply(req) + + require.NoError(pt, w.WriteMsg(resp)) + }) testCases := []struct { req netip.Addr wantErr error - locUpstream upstream.Upstream + locUpstream dns.Handler name string want string wantTTL time.Duration @@ -1570,35 +1561,35 @@ func TestServer_Exchange(t *testing.T) { name: "local_good", want: localDomainHost, wantErr: nil, - locUpstream: locUpstream, + locUpstream: locUpsHdlr, req: localIP, wantTTL: defaultTTL, }, { name: "upstream_error", want: "", - wantErr: aghtest.ErrUpstream, - locUpstream: errUpstream, + wantErr: ErrRDNSFailed, + locUpstream: errUpsHdlr, req: localIP, wantTTL: 0, }, { name: "empty_answer_error", want: "", wantErr: ErrRDNSNoData, - locUpstream: locUpstream, + locUpstream: locUpsHdlr, req: netip.MustParseAddr("192.168.1.2"), wantTTL: 0, }, { name: "invalid_answer", want: "", wantErr: ErrRDNSNoData, - locUpstream: nonPtrUpstream, + locUpstream: nonPtrHdlr, req: localIP, wantTTL: 0, }, { name: "refused", want: "", wantErr: ErrRDNSFailed, - locUpstream: refusingUpstream, + locUpstream: refusingHdlr, req: localIP, wantTTL: 0, }, { @@ -1612,23 +1603,28 @@ func TestServer_Exchange(t *testing.T) { name: "zero_ttl", want: localDomainHost, wantErr: nil, - locUpstream: zeroTTLUps, + locUpstream: zeroTTLHdlr, req: localIP, wantTTL: 0, }} for _, tc := range testCases { - pcfg := proxy.Config{ - UpstreamConfig: &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{tc.locUpstream}, - }, - } - srv.localResolvers = &proxy.Proxy{ - Config: pcfg, - } - require.NoError(t, srv.localResolvers.Init()) + localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String() t.Run(tc.name, func(t *testing.T) { + srv := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + Config: Config{ + UpstreamDNS: []string{upsAddr}, + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + LocalPTRResolvers: []string{localUpsAddr}, + UsePrivateRDNS: true, + ServePlainDNS: true, + }) + host, ttl, eerr := srv.Exchange(tc.req) require.ErrorIs(t, eerr, tc.wantErr) @@ -1638,8 +1634,17 @@ func TestServer_Exchange(t *testing.T) { } t.Run("resolving_disabled", func(t *testing.T) { - srv.conf.UsePrivateRDNS = false - t.Cleanup(func() { srv.conf.UsePrivateRDNS = true }) + srv := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + Config: Config{ + UpstreamDNS: []string{upsAddr}, + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + LocalPTRResolvers: []string{}, + ServePlainDNS: true, + }) host, _, eerr := srv.Exchange(localIP) diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_test.go index 5204c2f2..8f26ac85 100644 --- a/internal/dnsforward/dnsrewrite_test.go +++ b/internal/dnsforward/dnsrewrite_test.go @@ -42,7 +42,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) makeQ := func(qtype rules.RRType) (req *dns.Msg) { return &dns.Msg{ diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 66499746..408e2e46 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -83,7 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) { ConfigModified: func() {}, ServePlainDNS: true, } - s := createTestServer(t, filterConf, forwardConf, nil) + s := createTestServer(t, filterConf, forwardConf) s.sysResolvers = &emptySysResolvers{} require.NoError(t, s.Start()) @@ -164,7 +164,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { ConfigModified: func() {}, ServePlainDNS: true, } - s := createTestServer(t, filterConf, forwardConf, nil) + s := createTestServer(t, filterConf, forwardConf) s.sysResolvers = &emptySysResolvers{} defaultConf := s.conf @@ -439,7 +439,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) srv.etcHosts = upstream.NewHostsResolver(hc) startDeferStop(t, srv) diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 2c919d7d..5dc4e21b 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -9,7 +9,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" @@ -87,7 +86,7 @@ func TestServer_ProcessInitial(t *testing.T) { s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, - }, c, nil) + }, c) var gotAddr netip.Addr s.addrProc = &aghtest.AddressProcessor{ @@ -188,7 +187,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) { s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, - }, c, nil) + }, c) resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns) dctx := &dnsContext{ @@ -248,9 +247,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host string want []*dns.SVCB wantRes resultCode - portDoH int - portDoT int - portDoQ int + addrsDoH []*net.TCPAddr + addrsDoT []*net.TCPAddr + addrsDoQ []*net.UDPAddr qtype uint16 ddrEnabled bool }{{ @@ -259,14 +258,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: testQuestionTarget, qtype: dns.TypeSVCB, ddrEnabled: true, - portDoH: 8043, + addrsDoH: []*net.TCPAddr{{Port: 8043}}, }, { name: "pass_qtype", wantRes: resultCodeFinish, host: ddrHostFQDN, qtype: dns.TypeA, ddrEnabled: true, - portDoH: 8043, + addrsDoH: []*net.TCPAddr{{Port: 8043}}, }, { name: "pass_disabled_tls", wantRes: resultCodeFinish, @@ -279,7 +278,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: ddrHostFQDN, qtype: dns.TypeSVCB, ddrEnabled: false, - portDoH: 8043, + addrsDoH: []*net.TCPAddr{{Port: 8043}}, }, { name: "dot", wantRes: resultCodeFinish, @@ -287,7 +286,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: ddrHostFQDN, qtype: dns.TypeSVCB, ddrEnabled: true, - portDoT: 8043, + addrsDoT: []*net.TCPAddr{{Port: 8043}}, }, { name: "doh", wantRes: resultCodeFinish, @@ -295,7 +294,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: ddrHostFQDN, qtype: dns.TypeSVCB, ddrEnabled: true, - portDoH: 8044, + addrsDoH: []*net.TCPAddr{{Port: 8044}}, }, { name: "doq", wantRes: resultCodeFinish, @@ -303,7 +302,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: ddrHostFQDN, qtype: dns.TypeSVCB, ddrEnabled: true, - portDoQ: 8042, + addrsDoQ: []*net.UDPAddr{{Port: 8042}}, }, { name: "dot_doh", wantRes: resultCodeFinish, @@ -311,13 +310,35 @@ func TestServer_ProcessDDRQuery(t *testing.T) { host: ddrHostFQDN, qtype: dns.TypeSVCB, ddrEnabled: true, - portDoT: 8043, - portDoH: 8044, + addrsDoT: []*net.TCPAddr{{Port: 8043}}, + addrsDoH: []*net.TCPAddr{{Port: 8044}}, }} + _, certPem, keyPem := createServerTLSConfig(t) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled) + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + Config: Config{ + HandleDDR: tc.ddrEnabled, + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + TLSConfig: TLSConfig{ + ServerName: ddrTestDomainName, + CertificateChainData: certPem, + PrivateKeyData: keyPem, + TLSListenAddrs: tc.addrsDoT, + HTTPSListenAddrs: tc.addrsDoH, + QUICListenAddrs: tc.addrsDoQ, + }, + ServePlainDNS: true, + }) + // TODO(e.burkov): Generate a certificate actually containing the + // IP addresses. + s.conf.hasIPAddrs = true req := createTestMessageWithType(tc.host, tc.qtype) @@ -358,41 +379,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) { return f } -func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) { - t.Helper() - - s = &Server{ - dnsFilter: createTestDNSFilter(t), - dnsProxy: &proxy.Proxy{ - Config: proxy.Config{}, - }, - conf: ServerConfig{ - Config: Config{ - HandleDDR: ddrEnabled, - }, - TLSConfig: TLSConfig{ - ServerName: ddrTestDomainName, - }, - ServePlainDNS: true, - }, - } - - if portDoT > 0 { - s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}} - s.conf.hasIPAddrs = true - } - - if portDoQ > 0 { - s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}} - } - - if portDoH > 0 { - s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}} - } - - return s -} - func TestServer_ProcessDetermineLocal(t *testing.T) { s := &Server{ privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), @@ -680,13 +666,16 @@ func TestServer_ProcessRestrictLocal(t *testing.T) { intPTRAnswer = "some.local-client." ) - ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( + localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := aghalg.Coalesce( aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil + ) + + require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) }) + localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() s := createTestServer(t, &filtering.Config{ BlockingMode: filtering.BlockingModeDefault, @@ -696,12 +685,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) { // TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true. // Improve Config declaration for tests. Config: Config{ + UpstreamDNS: []string{localUpsAddr}, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, - ServePlainDNS: true, - }, ups) - s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups} + UsePrivateRDNS: true, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, + }) startDeferStop(t, s) testCases := []struct { @@ -764,6 +755,16 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { const locDomain = "some.local." const reqAddr = "1.1.168.192.in-addr.arpa." + localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := aghalg.Coalesce( + aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) + }) + localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() + s := createTestServer( t, &filtering.Config{ @@ -776,14 +777,10 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, - ServePlainDNS: true, + UsePrivateRDNS: true, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, }, - aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil - }), ) var proxyCtx *proxy.DNSContext diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_test.go index 2c2b7b0b..c5dbff6f 100644 --- a/internal/dnsforward/svcbmsg_test.go +++ b/internal/dnsforward/svcbmsg_test.go @@ -21,7 +21,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, }, ServePlainDNS: true, - }, nil) + }) req := &dns.Msg{ Question: []dns.Question{{ diff --git a/internal/next/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go index 68c2e7e7..345af7bc 100644 --- a/internal/next/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -67,19 +67,15 @@ func New(c *Config) (svc *Service, err error) { } svc.bootstrapResolvers = resolvers - svc.proxy = &proxy.Proxy{ - Config: proxy.Config{ - UDPListenAddr: udpAddrs(c.Addresses), - TCPListenAddr: tcpAddrs(c.Addresses), - UpstreamConfig: &proxy.UpstreamConfig{ - Upstreams: upstreams, - }, - UseDNS64: c.UseDNS64, - DNS64Prefs: c.DNS64Prefixes, + svc.proxy, err = proxy.New(&proxy.Config{ + UDPListenAddr: udpAddrs(c.Addresses), + TCPListenAddr: tcpAddrs(c.Addresses), + UpstreamConfig: &proxy.UpstreamConfig{ + Upstreams: upstreams, }, - } - - err = svc.proxy.Init() + UseDNS64: c.UseDNS64, + DNS64Prefs: c.DNS64Prefixes, + }) if err != nil { return nil, fmt.Errorf("proxy: %w", err) } @@ -174,7 +170,7 @@ func (svc *Service) Start() (err error) { svc.running.Store(err == nil) }() - return svc.proxy.Start() + return svc.proxy.Start(context.Background()) } // Shutdown implements the [agh.Service] interface for *Service. svc may be @@ -185,7 +181,7 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { } errs := []error{ - svc.proxy.Stop(), + svc.proxy.Shutdown(ctx), } for _, b := range svc.bootstrapResolvers {