From 73358263e82a92e611944eb54cc3f011ffec082a Mon Sep 17 00:00:00 2001
From: Ainar Garipov <a.garipov@adguard.com>
Date: Wed, 22 Nov 2023 13:49:02 +0300
Subject: [PATCH] Pull request 2076: 1660-disable-plain

Updates #1660.

Squashed commit of the following:

commit d928a00b7c77a33717fe3e77aace1f1b41a960d2
Merge: 38e401d78 0f5e8ca56
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Nov 22 13:39:34 2023 +0300

    Merge branch 'master' into 1660-disable-plain

commit 38e401d7827ce1ea190b5328cadb3bb0ff5a5cba
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 21 20:17:53 2023 +0300

    dnsforward: imp validation

commit f9e99cec209078128fef1b147294c7abe3f6ae70
Merge: cb7529682 c8f1112d4
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Nov 20 16:02:31 2023 +0300

    Merge branch 'master' into 1660-disable-plain

commit cb75296821cae594e8c4d17dfdd8be2190aee7f7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Nov 17 14:20:02 2023 +0300

    all: add serve_plain_dns
---
 CHANGELOG.md                                 |  5 ++
 internal/dnsforward/config.go                | 51 ++++++++++++++++----
 internal/dnsforward/dns64_test.go            |  1 +
 internal/dnsforward/dnsforward.go            | 13 +++--
 internal/dnsforward/dnsforward_test.go       | 18 +++++++
 internal/dnsforward/dnsrewrite_test.go       |  1 +
 internal/dnsforward/filter_test.go           |  1 +
 internal/dnsforward/http_test.go             |  3 ++
 internal/dnsforward/process_internal_test.go |  5 ++
 internal/dnsforward/svcbmsg_test.go          |  1 +
 internal/home/config.go                      |  4 ++
 internal/home/dns.go                         | 12 ++---
 12 files changed, 94 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d53dd12..d577b25a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,10 +25,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
 
 ### Added
 
+- Ability to disable plain-DNS serving through configuration file if an
+  encrypted protocol is already used ([#1660]).
 - Ability to specify rate limiting settings in the Web UI ([#6369]).
 
 #### Configuration changes
 
+- The new property `dns.serve_plain_dns` has been added to the configuration
+  file ([#1660]).
 - The property `dns.bogus_nxdomain` is now validated more strictly.
 - Added new properties `clients.persistent.*.upstreams_cache_enabled` and
   `clients.persistent.*.upstreams_cache_size` that describe cache configuration
@@ -40,6 +44,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
 - Pre-filling the New static lease window with data ([#6402]).
 - Protection pause timer synchronization ([#5759]).
 
+[#1660]: https://github.com/AdguardTeam/AdGuardHome/issues/1660
 [#5759]: https://github.com/AdguardTeam/AdGuardHome/issues/5759
 [#6369]: https://github.com/AdguardTeam/AdGuardHome/issues/6369
 [#6402]: https://github.com/AdguardTeam/AdGuardHome/issues/6402
diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go
index 9b5bdc8c..e6f7c57d 100644
--- a/internal/dnsforward/config.go
+++ b/internal/dnsforward/config.go
@@ -289,14 +289,15 @@ type ServerConfig struct {
 	// UseHTTP3Upstreams defines if HTTP/3 is be allowed for DNS-over-HTTPS
 	// upstreams.
 	UseHTTP3Upstreams bool
+
+	// ServePlainDNS defines if plain DNS is allowed for incoming requests.
+	ServePlainDNS bool
 }
 
-// createProxyConfig creates and validates configuration for the main proxy.
-func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
+// newProxyConfig creates and validates configuration for the main proxy.
+func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
 	srvConf := s.conf
-	conf = proxy.Config{
-		UDPListenAddr:           srvConf.UDPListenAddrs,
-		TCPListenAddr:           srvConf.TCPListenAddrs,
+	conf = &proxy.Config{
 		HTTP3:                   srvConf.ServeHTTP3,
 		Ratelimit:               int(srvConf.Ratelimit),
 		RatelimitSubnetMaskIPv4: net.CIDRMask(srvConf.RatelimitSubnetLenIPv4, netutil.IPv4BitLen),
@@ -328,7 +329,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
 	}
 
 	setProxyUpstreamMode(
-		&conf,
+		conf,
 		srvConf.AllServers,
 		srvConf.FastestAddr,
 		srvConf.FastestTimeout.Duration,
@@ -336,12 +337,17 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
 
 	conf.BogusNXDomain, err = parseBogusNXDOMAIN(srvConf.BogusNXDomain)
 	if err != nil {
-		return proxy.Config{}, fmt.Errorf("bogus_nxdomain: %w", err)
+		return nil, fmt.Errorf("bogus_nxdomain: %w", err)
 	}
 
-	err = s.prepareTLS(&conf)
+	err = s.prepareTLS(conf)
 	if err != nil {
-		return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
+		return nil, fmt.Errorf("validating tls: %w", err)
+	}
+
+	err = s.preparePlain(conf)
+	if err != nil {
+		return nil, fmt.Errorf("validating plain: %w", err)
 	}
 
 	if c := srvConf.DNSCryptConfig; c.Enabled {
@@ -352,7 +358,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
 	}
 
 	if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
-		return proxy.Config{}, errors.Error("no default upstream servers configured")
+		return nil, errors.Error("no default upstream servers configured")
 	}
 
 	return conf, nil
@@ -664,6 +670,31 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
 	return &s.conf.cert, nil
 }
 
+// preparePlain prepares the plain-DNS configuration for the DNS proxy.
+// preparePlain assumes that prepareTLS has already been called.
+func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
+	if s.conf.ServePlainDNS {
+		proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
+		proxyConf.TCPListenAddr = s.conf.TCPListenAddrs
+
+		return nil
+	}
+
+	lenEncrypted := len(proxyConf.DNSCryptTCPListenAddr) +
+		len(proxyConf.DNSCryptUDPListenAddr) +
+		len(proxyConf.HTTPSListenAddr) +
+		len(proxyConf.QUICListenAddr) +
+		len(proxyConf.TLSListenAddr)
+	if lenEncrypted == 0 {
+		// TODO(a.garipov): Support full disabling of all DNS.
+		return errors.Error("disabling plain dns requires at least one encrypted protocol")
+	}
+
+	log.Info("dnsforward: warning: plain dns is disabled")
+
+	return nil
+}
+
 // UpdatedProtectionStatus updates protection state, if the protection was
 // disabled temporarily.  Returns the updated state of protection.
 func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
diff --git a/internal/dnsforward/dns64_test.go b/internal/dnsforward/dns64_test.go
index 53a18c4e..55c08db7 100644
--- a/internal/dnsforward/dns64_test.go
+++ b/internal/dnsforward/dns64_test.go
@@ -292,6 +292,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
 			Config: Config{
 				EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 			},
+			ServePlainDNS: true,
 		}, localUps)
 
 		t.Run(tc.name, func(t *testing.T) {
diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go
index a6d6ef2a..d3bca111 100644
--- a/internal/dnsforward/dnsforward.go
+++ b/internal/dnsforward/dnsforward.go
@@ -109,7 +109,7 @@ type Server struct {
 	// stats is the statistics collector for client's DNS usage data.
 	stats stats.Interface
 
-	// access drops unallowed clients.
+	// access drops disallowed clients.
 	access *accessManager
 
 	// localDomainSuffix is the suffix used to detect internal hosts.  It
@@ -232,8 +232,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
 	if p.Anonymizer == nil {
 		p.Anonymizer = aghnet.NewIPMut(nil)
 	}
+
 	s = &Server{
 		dnsFilter:   p.DNSFilter,
+		dhcpServer:  p.DHCPServer,
 		stats:       p.Stats,
 		queryLog:    p.QueryLog,
 		privateNets: p.PrivateNets,
@@ -246,6 +248,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
 			MaxCount:  defaultClientIDCacheCount,
 		}),
 		anonymizer: p.Anonymizer,
+		conf: ServerConfig{
+			ServePlainDNS: true,
+		},
 	}
 
 	s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
@@ -253,8 +258,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
 		return nil, fmt.Errorf("initializing system resolvers: %w", err)
 	}
 
-	s.dhcpServer = p.DHCPServer
-
 	if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
 		// Use plain DNS on MIPS, encryption is too slow
 		defaultDNS = defaultBootstrap
@@ -540,7 +543,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
 		return err
 	}
 
-	proxyConfig, err := s.createProxyConfig()
+	proxyConfig, err := s.newProxyConfig()
 	if err != nil {
 		return fmt.Errorf("preparing proxy: %w", err)
 	}
@@ -559,7 +562,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
 	// Set the proxy here because [setupLocalResolvers] sets its values.
 	//
 	// TODO(e.burkov):  Remove once the local resolvers logic moved to dnsproxy.
-	s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
+	s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
 
 	err = s.setupLocalResolvers(boot)
 	if err != nil {
diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go
index c74d97f7..5353215c 100644
--- a/internal/dnsforward/dnsforward_test.go
+++ b/internal/dnsforward/dnsforward_test.go
@@ -182,6 +182,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 
 	tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
@@ -309,6 +310,7 @@ func TestServer(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
 	startDeferStop(t, s)
@@ -347,6 +349,7 @@ func TestServer_timeout(t *testing.T) {
 			Config: Config{
 				EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 			},
+			ServePlainDNS: true,
 		}
 
 		s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
@@ -381,6 +384,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
 			},
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}
 
 	s, err := NewServer(DNSCreateParams{})
@@ -402,6 +406,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
 	startDeferStop(t, s)
@@ -479,6 +484,7 @@ func TestServerRace(t *testing.T) {
 			UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
 		},
 		ConfigModified: func() {},
+		ServePlainDNS:  true,
 	}
 	s := createTestServer(t, filterConf, forwardConf, nil)
 	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
@@ -532,6 +538,7 @@ func TestSafeSearch(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, filterConf, forwardConf, nil)
 	startDeferStop(t, s)
@@ -594,6 +601,7 @@ func TestInvalidRequest(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}, nil)
 	startDeferStop(t, s)
 
@@ -622,6 +630,7 @@ func TestBlockedRequest(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, &filtering.Config{
 		ProtectionEnabled: true,
@@ -657,6 +666,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, &filtering.Config{
 		BlockingMode: filtering.BlockingModeDefault,
@@ -733,6 +743,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}, nil)
 	testUpstm := &aghtest.Upstream{
 		CName: testCNAMEs,
@@ -765,6 +776,7 @@ func TestBlockCNAME(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, &filtering.Config{
 		ProtectionEnabled: true,
@@ -839,6 +851,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, &filtering.Config{
 		BlockingMode: filtering.BlockingModeDefault,
@@ -883,6 +896,7 @@ func TestNullBlockedRequest(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, &filtering.Config{
 		ProtectionEnabled: true,
@@ -948,6 +962,7 @@ func TestBlockedCustomIP(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 
 	// Invalid BlockingIPv4.
@@ -999,6 +1014,7 @@ func TestBlockedByHosts(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 
 	s := createTestServer(t, &filtering.Config{
@@ -1049,6 +1065,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	s := createTestServer(t, filterConf, forwardConf, nil)
 	startDeferStop(t, s)
@@ -1107,6 +1124,7 @@ func TestRewrite(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}))
 
 	ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_test.go
index 79aecdef..1022388f 100644
--- a/internal/dnsforward/dnsrewrite_test.go
+++ b/internal/dnsforward/dnsrewrite_test.go
@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 
 	makeQ := func(qtype rules.RRType) (req *dns.Msg) {
diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go
index fe64cdf0..961ddcf7 100644
--- a/internal/dnsforward/filter_test.go
+++ b/internal/dnsforward/filter_test.go
@@ -35,6 +35,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
 				Enabled: false,
 			},
 		},
+		ServePlainDNS: true,
 	}
 	filters := []filtering.Filter{{
 		ID: 0, Data: []byte(rules),
diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go
index 99b03786..b16c26df 100644
--- a/internal/dnsforward/http_test.go
+++ b/internal/dnsforward/http_test.go
@@ -79,6 +79,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
 			EDNSClientSubnet:       &EDNSClientSubnet{Enabled: false},
 		},
 		ConfigModified: func() {},
+		ServePlainDNS:  true,
 	}
 	s := createTestServer(t, filterConf, forwardConf, nil)
 	s.sysResolvers = &emptySysResolvers{}
@@ -158,6 +159,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
 			EDNSClientSubnet:       &EDNSClientSubnet{Enabled: false},
 		},
 		ConfigModified: func() {},
+		ServePlainDNS:  true,
 	}
 	s := createTestServer(t, filterConf, forwardConf, nil)
 	s.sysResolvers = &emptySysResolvers{}
@@ -533,6 +535,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 	srv.etcHosts = hc
 	startDeferStop(t, srv)
diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go
index 168a97a1..a6977247 100644
--- a/internal/dnsforward/process_internal_test.go
+++ b/internal/dnsforward/process_internal_test.go
@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
 					AAAADisabled:     tc.aaaaDisabled,
 					EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 				},
+				ServePlainDNS: true,
 			}
 
 			s := createTestServer(t, &filtering.Config{
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
 					AAAADisabled:     tc.aaaaDisabled,
 					EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 				},
+				ServePlainDNS: true,
 			}
 
 			s := createTestServer(t, &filtering.Config{
@@ -369,6 +371,7 @@ func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled b
 			TLSConfig: TLSConfig{
 				ServerName: ddrTestDomainName,
 			},
+			ServePlainDNS: true,
 		},
 	}
 
@@ -699,6 +702,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, ups)
 	s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
 	startDeferStop(t, s)
@@ -776,6 +780,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
 			Config: Config{
 				EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 			},
+			ServePlainDNS: true,
 		},
 		aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
 			return aghalg.Coalesce(
diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_test.go
index 83645b32..58275ef4 100644
--- a/internal/dnsforward/svcbmsg_test.go
+++ b/internal/dnsforward/svcbmsg_test.go
@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
 		Config: Config{
 			EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
 		},
+		ServePlainDNS: true,
 	}, nil)
 
 	req := &dns.Msg{
diff --git a/internal/home/config.go b/internal/home/config.go
index 92e78697..e78bf8e9 100644
--- a/internal/home/config.go
+++ b/internal/home/config.go
@@ -228,6 +228,9 @@ type dnsConfig struct {
 	// TODO(a.garipov): Add to the UI when HTTP/3 support is no longer
 	// experimental.
 	UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
+
+	// ServePlainDNS defines if plain DNS is allowed for incoming requests.
+	ServePlainDNS bool `yaml:"serve_plain_dns"`
 }
 
 type tlsConfigSettings struct {
@@ -335,6 +338,7 @@ var config = &configuration{
 		},
 		UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
 		UsePrivateRDNS:  true,
+		ServePlainDNS:   true,
 	},
 	TLS: tlsConfigSettings{
 		PortHTTPS:       defaultPortHTTPS,
diff --git a/internal/home/dns.go b/internal/home/dns.go
index 5c6ad2cc..592e84f2 100644
--- a/internal/home/dns.go
+++ b/internal/home/dns.go
@@ -142,9 +142,12 @@ func initDNSServer(
 		EtcHosts:    Context.etcHosts,
 		LocalDomain: config.DHCP.LocalDomainName,
 	})
+	defer func() {
+		if err != nil {
+			closeDNSServer()
+		}
+	}()
 	if err != nil {
-		closeDNSServer()
-
 		return fmt.Errorf("dnsforward.NewServer: %w", err)
 	}
 
@@ -152,15 +155,11 @@ func initDNSServer(
 
 	dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
 	if err != nil {
-		closeDNSServer()
-
 		return fmt.Errorf("newServerConfig: %w", err)
 	}
 
 	err = Context.dnsServer.Prepare(dnsConf)
 	if err != nil {
-		closeDNSServer()
-
 		return fmt.Errorf("dnsServer.Prepare: %w", err)
 	}
 
@@ -253,6 +252,7 @@ func newServerConfig(
 		UsePrivateRDNS:         dnsConf.UsePrivateRDNS,
 		ServeHTTP3:             dnsConf.ServeHTTP3,
 		UseHTTP3Upstreams:      dnsConf.UseHTTP3Upstreams,
+		ServePlainDNS:          dnsConf.ServePlainDNS,
 	}
 
 	var initialAddresses []netip.Addr