From 841bb9bc35786c710414706ef38c2830024fbfed Mon Sep 17 00:00:00 2001
From: Ainar Garipov <a.garipov@adguard.com>
Date: Thu, 11 Feb 2021 15:20:30 +0300
Subject: [PATCH] Pull request: dnsforward: do not check client srv name unless
 asked

Merge in DNS/adguard-home from 2664-non-strict-sni to master

Updates #2664.

Squashed commit of the following:

commit e8d625fe3b1f06f97328809a3330b37e5bd578d7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 11 14:46:52 2021 +0300

    all: imp doc

commit 10537b8bdf126eca9608353e57d92edba632232a
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Feb 11 14:30:25 2021 +0300

    dnsforward: do not check client srv name unless asked
---
 CHANGELOG.md                                  |   9 +-
 internal/dnsforward/clientid.go               | 165 ++++++++++++++++++
 .../{dns_test.go => clientid_test.go}         |  80 ++++++---
 internal/dnsforward/dns.go                    | 152 ----------------
 4 files changed, 230 insertions(+), 176 deletions(-)
 create mode 100644 internal/dnsforward/clientid.go
 rename internal/dnsforward/{dns_test.go => clientid_test.go} (73%)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7560957b..bfae45f1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,17 +19,20 @@ and this project adheres to
 
 ### Changed
 
-- Increase the HTTP API request body size limit for the `/control/access/set`
-  API ([#2666]).
+- The server name sent by clients of TLS APIs is not only checked when
+  `strict_sni_check` is enabled ([#2664]).
+- HTTP API request body size limit for the `/control/access/set` API is
+  increased ([#2666]).
 
 ### Fixed
 
-- Set the request body size limit for HTTPS reqeusts as well.
+- The request body size limit is now set for HTTPS requests as well.
 - Incorrect version tag in the Docker release ([#2663]).
 - DNSCrypt queries weren't marked as such in logs ([#2662]).
 
 [#2662]: https://github.com/AdguardTeam/AdGuardHome/issues/2662
 [#2663]: https://github.com/AdguardTeam/AdGuardHome/issues/2663
+[#2664]: https://github.com/AdguardTeam/AdGuardHome/issues/2664
 [#2666]: https://github.com/AdguardTeam/AdGuardHome/issues/2666
 
 
diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go
new file mode 100644
index 00000000..c497c7b7
--- /dev/null
+++ b/internal/dnsforward/clientid.go
@@ -0,0 +1,165 @@
+package dnsforward
+
+import (
+	"crypto/tls"
+	"fmt"
+	"path"
+	"strings"
+
+	"github.com/AdguardTeam/dnsproxy/proxy"
+	"github.com/lucas-clemente/quic-go"
+)
+
+const maxDomainPartLen = 64
+
+// ValidateClientID returns an error if clientID is not a valid client ID.
+func ValidateClientID(clientID string) (err error) {
+	if len(clientID) > maxDomainPartLen {
+		return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
+	}
+
+	for i, r := range clientID {
+		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
+			continue
+		}
+
+		return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
+	}
+
+	return nil
+}
+
+// clientIDFromClientServerName extracts and validates a client ID.  hostSrvName
+// is the server name of the host.  cliSrvName is the server name as sent by the
+// client.  When strict is true, and client and host server name don't match,
+// clientIDFromClientServerName will return an error.
+func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) {
+	if hostSrvName == cliSrvName {
+		return "", nil
+	}
+
+	if !strings.HasSuffix(cliSrvName, hostSrvName) {
+		if !strict {
+			return "", nil
+		}
+
+		return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
+	}
+
+	clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
+	err = ValidateClientID(clientID)
+	if err != nil {
+		return "", fmt.Errorf("invalid client id: %w", err)
+	}
+
+	return clientID, nil
+}
+
+// processClientIDHTTPS extracts the client's ID from the path of the
+// client's DNS-over-HTTPS request.
+func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
+	pctx := ctx.proxyCtx
+	r := pctx.HTTPRequest
+	if r == nil {
+		ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
+
+		return resultCodeError
+	}
+
+	origPath := r.URL.Path
+	parts := strings.Split(path.Clean(origPath), "/")
+	if parts[0] == "" {
+		parts = parts[1:]
+	}
+
+	if len(parts) == 0 || parts[0] != "dns-query" {
+		ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
+
+		return resultCodeError
+	}
+
+	clientID := ""
+	switch len(parts) {
+	case 1:
+		// Just /dns-query, no client ID.
+		return resultCodeSuccess
+	case 2:
+		clientID = parts[1]
+	default:
+		ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
+
+		return resultCodeError
+	}
+
+	err := ValidateClientID(clientID)
+	if err != nil {
+		ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
+
+		return resultCodeError
+	}
+
+	ctx.clientID = clientID
+
+	return resultCodeSuccess
+}
+
+// tlsConn is a narrow interface for *tls.Conn to simplify testing.
+type tlsConn interface {
+	ConnectionState() (cs tls.ConnectionState)
+}
+
+// quicSession is a narrow interface for quic.Session to simplify testing.
+type quicSession interface {
+	ConnectionState() (cs quic.ConnectionState)
+}
+
+// processClientID extracts the client's ID from the server name of the client's
+// DOT or DOQ request or the path of the client's DOH.
+func processClientID(dctx *dnsContext) (rc resultCode) {
+	pctx := dctx.proxyCtx
+	proto := pctx.Proto
+	if proto == proxy.ProtoHTTPS {
+		return processClientIDHTTPS(dctx)
+	} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
+		return resultCodeSuccess
+	}
+
+	srvConf := dctx.srv.conf
+	hostSrvName := srvConf.TLSConfig.ServerName
+	if hostSrvName == "" {
+		return resultCodeSuccess
+	}
+
+	cliSrvName := ""
+	if proto == proxy.ProtoTLS {
+		conn := pctx.Conn
+		tc, ok := conn.(tlsConn)
+		if !ok {
+			dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
+
+			return resultCodeError
+		}
+
+		cliSrvName = tc.ConnectionState().ServerName
+	} else if proto == proxy.ProtoQUIC {
+		qs, ok := pctx.QUICSession.(quicSession)
+		if !ok {
+			dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
+
+			return resultCodeError
+		}
+
+		cliSrvName = qs.ConnectionState().ServerName
+	}
+
+	clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
+	if err != nil {
+		dctx.err = fmt.Errorf("client id check: %w", err)
+
+		return resultCodeError
+	}
+
+	dctx.clientID = clientID
+
+	return resultCodeSuccess
+}
diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/clientid_test.go
similarity index 73%
rename from internal/dnsforward/dns_test.go
rename to internal/dnsforward/clientid_test.go
index bd0ef4ab..503203f9 100644
--- a/internal/dnsforward/dns_test.go
+++ b/internal/dnsforward/clientid_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/lucas-clemente/quic-go"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 // testTLSConn is a tlsConn for tests.
@@ -53,6 +54,7 @@ func TestProcessClientID(t *testing.T) {
 		wantClientID string
 		wantErrMsg   string
 		wantRes      resultCode
+		strictSNI    bool
 	}{{
 		name:         "udp",
 		proto:        proxy.ProtoUDP,
@@ -61,6 +63,7 @@ func TestProcessClientID(t *testing.T) {
 		wantClientID: "",
 		wantErrMsg:   "",
 		wantRes:      resultCodeSuccess,
+		strictSNI:    false,
 	}, {
 		name:         "tls_no_client_id",
 		proto:        proxy.ProtoTLS,
@@ -69,6 +72,26 @@ func TestProcessClientID(t *testing.T) {
 		wantClientID: "",
 		wantErrMsg:   "",
 		wantRes:      resultCodeSuccess,
+		strictSNI:    true,
+	}, {
+		name:         "tls_no_client_server_name",
+		proto:        proxy.ProtoTLS,
+		hostSrvName:  "example.com",
+		cliSrvName:   "",
+		wantClientID: "",
+		wantErrMsg: `client id check: client server name "" ` +
+			`doesn't match host server name "example.com"`,
+		wantRes:   resultCodeError,
+		strictSNI: true,
+	}, {
+		name:         "tls_no_client_server_name_no_strict",
+		proto:        proxy.ProtoTLS,
+		hostSrvName:  "example.com",
+		cliSrvName:   "",
+		wantClientID: "",
+		wantErrMsg:   "",
+		wantRes:      resultCodeSuccess,
+		strictSNI:    false,
 	}, {
 		name:         "tls_client_id",
 		proto:        proxy.ProtoTLS,
@@ -77,30 +100,39 @@ func TestProcessClientID(t *testing.T) {
 		wantClientID: "cli",
 		wantErrMsg:   "",
 		wantRes:      resultCodeSuccess,
+		strictSNI:    true,
 	}, {
 		name:         "tls_client_id_hostname_error",
 		proto:        proxy.ProtoTLS,
 		hostSrvName:  "example.com",
 		cliSrvName:   "cli.example.net",
 		wantClientID: "",
-		wantErrMsg:   `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`,
-		wantRes:      resultCodeError,
+		wantErrMsg: `client id check: client server name "cli.example.net" ` +
+			`doesn't match host server name "example.com"`,
+		wantRes:   resultCodeError,
+		strictSNI: true,
 	}, {
 		name:         "tls_invalid_client_id",
 		proto:        proxy.ProtoTLS,
 		hostSrvName:  "example.com",
 		cliSrvName:   "!!!.example.com",
 		wantClientID: "",
-		wantErrMsg:   `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
-		wantRes:      resultCodeError,
+		wantErrMsg: `client id check: invalid client id: invalid char '!' ` +
+			`at index 0 in client id "!!!"`,
+		wantRes:   resultCodeError,
+		strictSNI: true,
 	}, {
-		name:         "tls_client_id_too_long",
-		proto:        proxy.ProtoTLS,
-		hostSrvName:  "example.com",
-		cliSrvName:   "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com",
+		name:        "tls_client_id_too_long",
+		proto:       proxy.ProtoTLS,
+		hostSrvName: "example.com",
+		cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
+			`pqrstuvwxyz0123456789.example.com`,
 		wantClientID: "",
-		wantErrMsg:   `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`,
-		wantRes:      resultCodeError,
+		wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` +
+			`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` +
+			`is too long, max: 64`,
+		wantRes:   resultCodeError,
+		strictSNI: true,
 	}, {
 		name:         "quic_client_id",
 		proto:        proxy.ProtoQUIC,
@@ -109,14 +141,17 @@ func TestProcessClientID(t *testing.T) {
 		wantClientID: "cli",
 		wantErrMsg:   "",
 		wantRes:      resultCodeSuccess,
+		strictSNI:    true,
 	}}
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
+			tlsConf := TLSConfig{
+				ServerName:     tc.hostSrvName,
+				StrictSNICheck: tc.strictSNI,
+			}
 			srv := &Server{
-				conf: ServerConfig{
-					TLSConfig: TLSConfig{ServerName: tc.hostSrvName},
-				},
+				conf: ServerConfig{TLSConfig: tlsConf},
 			}
 
 			var conn net.Conn
@@ -146,10 +181,11 @@ func TestProcessClientID(t *testing.T) {
 			assert.Equal(t, tc.wantRes, res)
 			assert.Equal(t, tc.wantClientID, dctx.clientID)
 
-			if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
-				assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
-			} else {
+			if tc.wantErrMsg == "" {
 				assert.Nil(t, dctx.err)
+			} else {
+				require.NotNil(t, dctx.err)
+				assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
 			}
 		})
 	}
@@ -202,8 +238,9 @@ func TestProcessClientID_https(t *testing.T) {
 		name:         "invalid_client_id",
 		path:         "/dns-query/!!!",
 		wantClientID: "",
-		wantErrMsg:   `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
-		wantRes:      resultCodeError,
+		wantErrMsg: `client id check: invalid client id: invalid char '!'` +
+			` at index 0 in client id "!!!"`,
+		wantRes: resultCodeError,
 	}}
 
 	for _, tc := range testCases {
@@ -225,10 +262,11 @@ func TestProcessClientID_https(t *testing.T) {
 			assert.Equal(t, tc.wantRes, res)
 			assert.Equal(t, tc.wantClientID, dctx.clientID)
 
-			if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
-				assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
-			} else {
+			if tc.wantErrMsg == "" {
 				assert.Nil(t, dctx.err)
+			} else {
+				require.NotNil(t, dctx.err)
+				assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
 			}
 		})
 	}
diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go
index f8e7bff0..acc6aa86 100644
--- a/internal/dnsforward/dns.go
+++ b/internal/dnsforward/dns.go
@@ -1,10 +1,7 @@
 package dnsforward
 
 import (
-	"crypto/tls"
-	"fmt"
 	"net"
-	"path"
 	"strings"
 	"time"
 
@@ -13,7 +10,6 @@ import (
 	"github.com/AdguardTeam/AdGuardHome/internal/util"
 	"github.com/AdguardTeam/dnsproxy/proxy"
 	"github.com/AdguardTeam/golibs/log"
-	"github.com/lucas-clemente/quic-go"
 	"github.com/miekg/dns"
 )
 
@@ -234,154 +230,6 @@ func processInternalHosts(ctx *dnsContext) (rc resultCode) {
 	return resultCodeSuccess
 }
 
-const maxDomainPartLen = 64
-
-// ValidateClientID returns an error if clientID is not a valid client ID.
-func ValidateClientID(clientID string) (err error) {
-	if len(clientID) > maxDomainPartLen {
-		return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
-	}
-
-	for i, r := range clientID {
-		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
-			continue
-		}
-
-		return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
-	}
-
-	return nil
-}
-
-// clientIDFromClientServerName extracts and validates a client ID.  hostSrvName
-// is the server name of the host.  cliSrvName is the server name as sent by the
-// client.
-func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) {
-	if hostSrvName == cliSrvName {
-		return "", nil
-	}
-
-	if !strings.HasSuffix(cliSrvName, hostSrvName) {
-		return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
-	}
-
-	clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
-	err = ValidateClientID(clientID)
-	if err != nil {
-		return "", fmt.Errorf("invalid client id: %w", err)
-	}
-
-	return clientID, nil
-}
-
-// processClientIDHTTPS extracts the client's ID from the path of the
-// client's DNS-over-HTTPS request.
-func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
-	pctx := ctx.proxyCtx
-	r := pctx.HTTPRequest
-	if r == nil {
-		ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
-
-		return resultCodeError
-	}
-
-	origPath := r.URL.Path
-	parts := strings.Split(path.Clean(origPath), "/")
-	if parts[0] == "" {
-		parts = parts[1:]
-	}
-
-	if len(parts) == 0 || parts[0] != "dns-query" {
-		ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
-
-		return resultCodeError
-	}
-
-	clientID := ""
-	switch len(parts) {
-	case 1:
-		// Just /dns-query, no client ID.
-		return resultCodeSuccess
-	case 2:
-		clientID = parts[1]
-	default:
-		ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
-
-		return resultCodeError
-	}
-
-	err := ValidateClientID(clientID)
-	if err != nil {
-		ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
-
-		return resultCodeError
-	}
-
-	ctx.clientID = clientID
-
-	return resultCodeSuccess
-}
-
-// tlsConn is a narrow interface for *tls.Conn to simplify testing.
-type tlsConn interface {
-	ConnectionState() (cs tls.ConnectionState)
-}
-
-// quicSession is a narrow interface for quic.Session to simplify testing.
-type quicSession interface {
-	ConnectionState() (cs quic.ConnectionState)
-}
-
-// processClientID extracts the client's ID from the server name of the client's
-// DOT or DOQ request or the path of the client's DOH.
-func processClientID(ctx *dnsContext) (rc resultCode) {
-	pctx := ctx.proxyCtx
-	proto := pctx.Proto
-	if proto == proxy.ProtoHTTPS {
-		return processClientIDHTTPS(ctx)
-	} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
-		return resultCodeSuccess
-	}
-
-	hostSrvName := ctx.srv.conf.TLSConfig.ServerName
-	if hostSrvName == "" {
-		return resultCodeSuccess
-	}
-
-	cliSrvName := ""
-	if proto == proxy.ProtoTLS {
-		conn := pctx.Conn
-		tc, ok := conn.(tlsConn)
-		if !ok {
-			ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
-
-			return resultCodeError
-		}
-
-		cliSrvName = tc.ConnectionState().ServerName
-	} else if proto == proxy.ProtoQUIC {
-		qs, ok := pctx.QUICSession.(quicSession)
-		if !ok {
-			ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
-
-			return resultCodeError
-		}
-
-		cliSrvName = qs.ConnectionState().ServerName
-	}
-
-	clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName)
-	if err != nil {
-		ctx.err = fmt.Errorf("client id check: %w", err)
-
-		return resultCodeError
-	}
-
-	ctx.clientID = clientID
-
-	return resultCodeSuccess
-}
-
 // Respond to PTR requests if the target IP address is leased by our DHCP server
 func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
 	s := ctx.srv