From a1acfbbae4c222a6edecc705b2d2344d77a9f2be Mon Sep 17 00:00:00 2001
From: Ainar Garipov <a.garipov@adguard.com>
Date: Fri, 14 Oct 2022 19:03:03 +0300
Subject: [PATCH] Pull request: 4925-refactor-tls-vol-1

Merge in DNS/adguard-home from 4925-refactor-tls-vol-1 to master

Squashed commit of the following:

commit ad87b2e93183b28f2e38666cc4267fa8dfd1cca0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Oct 14 18:49:22 2022 +0300

    all: refactor tls, vol. 1

    Co-Authored-By: Rahul Somasundaram <Rahul.Somasundaram@checkpt.com>
---
 internal/aghtls/aghtls.go                     | 43 +++++++++++++-
 internal/aghtls/aghtls_test.go                | 57 +++++++++++++++++++
 internal/aghtls/root.go                       | 14 +++++
 internal/aghtls/root_linux.go                 | 56 ++++++++++++++++++
 internal/aghtls/root_others.go                |  9 +++
 internal/home/home.go                         |  2 +-
 internal/home/tls.go                          | 45 ---------------
 .../{control_test.go => tls_internal_test.go} |  0
 8 files changed, 179 insertions(+), 47 deletions(-)
 create mode 100644 internal/aghtls/aghtls_test.go
 create mode 100644 internal/aghtls/root.go
 create mode 100644 internal/aghtls/root_linux.go
 create mode 100644 internal/aghtls/root_others.go
 rename internal/home/{control_test.go => tls_internal_test.go} (100%)

diff --git a/internal/aghtls/aghtls.go b/internal/aghtls/aghtls.go
index 5dc7a382..bcceaad9 100644
--- a/internal/aghtls/aghtls.go
+++ b/internal/aghtls/aghtls.go
@@ -1,7 +1,48 @@
 // Package aghtls contains utilities for work with TLS.
 package aghtls
 
-import "crypto/tls"
+import (
+	"crypto/tls"
+	"fmt"
+
+	"github.com/AdguardTeam/golibs/log"
+)
+
+// init makes sure that the cipher name map is filled.
+//
+// TODO(a.garipov): Propose a similar API to crypto/tls.
+func init() {
+	suites := tls.CipherSuites()
+	cipherSuites = make(map[string]uint16, len(suites))
+	for _, s := range suites {
+		cipherSuites[s.Name] = s.ID
+	}
+
+	log.Debug("tls: known ciphers: %q", cipherSuites)
+}
+
+// cipherSuites are a name-to-ID mapping of cipher suites from crypto/tls.  It
+// is filled by init.  It must not be modified.
+var cipherSuites map[string]uint16
+
+// ParseCiphers parses a slice of cipher suites from cipher names.
+func ParseCiphers(cipherNames []string) (cipherIDs []uint16, err error) {
+	if cipherNames == nil {
+		return nil, nil
+	}
+
+	cipherIDs = make([]uint16, 0, len(cipherNames))
+	for _, name := range cipherNames {
+		id, ok := cipherSuites[name]
+		if !ok {
+			return nil, fmt.Errorf("unknown cipher %q", name)
+		}
+
+		cipherIDs = append(cipherIDs, id)
+	}
+
+	return cipherIDs, nil
+}
 
 // SaferCipherSuites returns a set of default cipher suites with vulnerable and
 // weak cipher suites removed.
diff --git a/internal/aghtls/aghtls_test.go b/internal/aghtls/aghtls_test.go
new file mode 100644
index 00000000..923ff063
--- /dev/null
+++ b/internal/aghtls/aghtls_test.go
@@ -0,0 +1,57 @@
+package aghtls_test
+
+import (
+	"crypto/tls"
+	"testing"
+
+	"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
+	"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
+	"github.com/AdguardTeam/golibs/testutil"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMain(m *testing.M) {
+	aghtest.DiscardLogOutput(m)
+}
+
+func TestParseCiphers(t *testing.T) {
+	testCases := []struct {
+		name       string
+		wantErrMsg string
+		want       []uint16
+		in         []string
+	}{{
+		name:       "nil",
+		wantErrMsg: "",
+		want:       nil,
+		in:         nil,
+	}, {
+		name:       "empty",
+		wantErrMsg: "",
+		want:       []uint16{},
+		in:         []string{},
+	}, {}, {
+		name:       "one",
+		wantErrMsg: "",
+		want:       []uint16{tls.TLS_AES_128_GCM_SHA256},
+		in:         []string{"TLS_AES_128_GCM_SHA256"},
+	}, {
+		name:       "several",
+		wantErrMsg: "",
+		want:       []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384},
+		in:         []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"},
+	}, {
+		name:       "bad",
+		wantErrMsg: `unknown cipher "bad_cipher"`,
+		want:       nil,
+		in:         []string{"bad_cipher"},
+	}}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := aghtls.ParseCiphers(tc.in)
+			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
+			assert.Equal(t, tc.want, got)
+		})
+	}
+}
diff --git a/internal/aghtls/root.go b/internal/aghtls/root.go
new file mode 100644
index 00000000..d81db143
--- /dev/null
+++ b/internal/aghtls/root.go
@@ -0,0 +1,14 @@
+package aghtls
+
+import (
+	"crypto/x509"
+)
+
+// SystemRootCAs tries to load root certificates from the operating system.  It
+// returns nil in case nothing is found so that Go' crypto/x509 can use its
+// default algorithm to find system root CA list.
+//
+// See https://github.com/AdguardTeam/AdGuardHome/issues/1311.
+func SystemRootCAs() (roots *x509.CertPool) {
+	return rootCAs()
+}
diff --git a/internal/aghtls/root_linux.go b/internal/aghtls/root_linux.go
new file mode 100644
index 00000000..0805f198
--- /dev/null
+++ b/internal/aghtls/root_linux.go
@@ -0,0 +1,56 @@
+//go:build linux
+
+package aghtls
+
+import (
+	"crypto/x509"
+	"os"
+	"path/filepath"
+
+	"github.com/AdguardTeam/golibs/errors"
+	"github.com/AdguardTeam/golibs/log"
+)
+
+func rootCAs() (roots *x509.CertPool) {
+	// Directories with the system root certificates, which aren't supported by
+	// Go's crypto/x509.
+	dirs := []string{
+		// Entware.
+		"/opt/etc/ssl/certs",
+	}
+
+	roots = x509.NewCertPool()
+	for _, dir := range dirs {
+		dirEnts, err := os.ReadDir(dir)
+		if err != nil {
+			if errors.Is(err, os.ErrNotExist) {
+				continue
+			}
+
+			// TODO(a.garipov): Improve error handling here and in other places.
+			log.Error("aghtls: opening directory %q: %s", dir, err)
+		}
+
+		var rootsAdded bool
+		for _, de := range dirEnts {
+			var certData []byte
+			rootFile := filepath.Join(dir, de.Name())
+			certData, err = os.ReadFile(rootFile)
+			if err != nil {
+				log.Error("aghtls: reading root cert: %s", err)
+			} else {
+				if roots.AppendCertsFromPEM(certData) {
+					rootsAdded = true
+				} else {
+					log.Error("aghtls: could not add root from %q", rootFile)
+				}
+			}
+		}
+
+		if rootsAdded {
+			return roots
+		}
+	}
+
+	return nil
+}
diff --git a/internal/aghtls/root_others.go b/internal/aghtls/root_others.go
new file mode 100644
index 00000000..38a50630
--- /dev/null
+++ b/internal/aghtls/root_others.go
@@ -0,0 +1,9 @@
+//go:build !linux
+
+package aghtls
+
+import "crypto/x509"
+
+func rootCAs() (roots *x509.CertPool) {
+	return nil
+}
diff --git a/internal/home/home.go b/internal/home/home.go
index adbf7bfb..dbe34f4e 100644
--- a/internal/home/home.go
+++ b/internal/home/home.go
@@ -147,7 +147,7 @@ func setupContext(opts options) {
 		// Go on.
 	}
 
-	Context.tlsRoots = LoadSystemRootCAs()
+	Context.tlsRoots = aghtls.SystemRootCAs()
 	Context.transport = &http.Transport{
 		DialContext: customDialContext,
 		Proxy:       getHTTPProxy,
diff --git a/internal/home/tls.go b/internal/home/tls.go
index a5089bd8..e7e8e35f 100644
--- a/internal/home/tls.go
+++ b/internal/home/tls.go
@@ -14,8 +14,6 @@ import (
 	"fmt"
 	"net/http"
 	"os"
-	"path/filepath"
-	"runtime"
 	"strings"
 	"sync"
 	"time"
@@ -699,46 +697,3 @@ func (t *TLSMod) registerWebHandlers() {
 	httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure)
 	httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate)
 }
-
-// LoadSystemRootCAs tries to load root certificates from the operating system.
-// It returns nil in case nothing is found so that that Go.crypto will use it's
-// default algorithm to find system root CA list.
-//
-// See https://github.com/AdguardTeam/AdGuardHome/internal/issues/1311.
-func LoadSystemRootCAs() (roots *x509.CertPool) {
-	// TODO(e.burkov): Use build tags instead.
-	if runtime.GOOS != "linux" {
-		return nil
-	}
-
-	// Directories with the system root certificates, which aren't supported
-	// by Go.crypto.
-	dirs := []string{
-		// Entware.
-		"/opt/etc/ssl/certs",
-	}
-	roots = x509.NewCertPool()
-	for _, dir := range dirs {
-		dirEnts, err := os.ReadDir(dir)
-		if errors.Is(err, os.ErrNotExist) {
-			continue
-		} else if err != nil {
-			log.Error("opening directory: %q: %s", dir, err)
-		}
-
-		var rootsAdded bool
-		for _, de := range dirEnts {
-			var certData []byte
-			certData, err = os.ReadFile(filepath.Join(dir, de.Name()))
-			if err == nil && roots.AppendCertsFromPEM(certData) {
-				rootsAdded = true
-			}
-		}
-
-		if rootsAdded {
-			return roots
-		}
-	}
-
-	return nil
-}
diff --git a/internal/home/control_test.go b/internal/home/tls_internal_test.go
similarity index 100%
rename from internal/home/control_test.go
rename to internal/home/tls_internal_test.go