Pull request: 4925-refactor-tls-vol-1

Merge in DNS/adguard-home from 4925-refactor-tls-vol-1 to master

Squashed commit of the following:

commit ad87b2e93183b28f2e38666cc4267fa8dfd1cca0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Oct 14 18:49:22 2022 +0300

    all: refactor tls, vol. 1

    Co-Authored-By: Rahul Somasundaram <Rahul.Somasundaram@checkpt.com>
This commit is contained in:
Ainar Garipov 2022-10-14 19:03:03 +03:00
parent 4582b1c919
commit a1acfbbae4
8 changed files with 179 additions and 47 deletions

View file

@ -1,7 +1,48 @@
// Package aghtls contains utilities for work with TLS. // Package aghtls contains utilities for work with TLS.
package aghtls package aghtls
import "crypto/tls" import (
"crypto/tls"
"fmt"
"github.com/AdguardTeam/golibs/log"
)
// init makes sure that the cipher name map is filled.
//
// TODO(a.garipov): Propose a similar API to crypto/tls.
func init() {
suites := tls.CipherSuites()
cipherSuites = make(map[string]uint16, len(suites))
for _, s := range suites {
cipherSuites[s.Name] = s.ID
}
log.Debug("tls: known ciphers: %q", cipherSuites)
}
// cipherSuites are a name-to-ID mapping of cipher suites from crypto/tls. It
// is filled by init. It must not be modified.
var cipherSuites map[string]uint16
// ParseCiphers parses a slice of cipher suites from cipher names.
func ParseCiphers(cipherNames []string) (cipherIDs []uint16, err error) {
if cipherNames == nil {
return nil, nil
}
cipherIDs = make([]uint16, 0, len(cipherNames))
for _, name := range cipherNames {
id, ok := cipherSuites[name]
if !ok {
return nil, fmt.Errorf("unknown cipher %q", name)
}
cipherIDs = append(cipherIDs, id)
}
return cipherIDs, nil
}
// SaferCipherSuites returns a set of default cipher suites with vulnerable and // SaferCipherSuites returns a set of default cipher suites with vulnerable and
// weak cipher suites removed. // weak cipher suites removed.

View file

@ -0,0 +1,57 @@
package aghtls_test
import (
"crypto/tls"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func TestParseCiphers(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want []uint16
in []string
}{{
name: "nil",
wantErrMsg: "",
want: nil,
in: nil,
}, {
name: "empty",
wantErrMsg: "",
want: []uint16{},
in: []string{},
}, {}, {
name: "one",
wantErrMsg: "",
want: []uint16{tls.TLS_AES_128_GCM_SHA256},
in: []string{"TLS_AES_128_GCM_SHA256"},
}, {
name: "several",
wantErrMsg: "",
want: []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384},
in: []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"},
}, {
name: "bad",
wantErrMsg: `unknown cipher "bad_cipher"`,
want: nil,
in: []string{"bad_cipher"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := aghtls.ParseCiphers(tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
}

14
internal/aghtls/root.go Normal file
View file

@ -0,0 +1,14 @@
package aghtls
import (
"crypto/x509"
)
// SystemRootCAs tries to load root certificates from the operating system. It
// returns nil in case nothing is found so that Go' crypto/x509 can use its
// default algorithm to find system root CA list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/1311.
func SystemRootCAs() (roots *x509.CertPool) {
return rootCAs()
}

View file

@ -0,0 +1,56 @@
//go:build linux
package aghtls
import (
"crypto/x509"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
func rootCAs() (roots *x509.CertPool) {
// Directories with the system root certificates, which aren't supported by
// Go's crypto/x509.
dirs := []string{
// Entware.
"/opt/etc/ssl/certs",
}
roots = x509.NewCertPool()
for _, dir := range dirs {
dirEnts, err := os.ReadDir(dir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
// TODO(a.garipov): Improve error handling here and in other places.
log.Error("aghtls: opening directory %q: %s", dir, err)
}
var rootsAdded bool
for _, de := range dirEnts {
var certData []byte
rootFile := filepath.Join(dir, de.Name())
certData, err = os.ReadFile(rootFile)
if err != nil {
log.Error("aghtls: reading root cert: %s", err)
} else {
if roots.AppendCertsFromPEM(certData) {
rootsAdded = true
} else {
log.Error("aghtls: could not add root from %q", rootFile)
}
}
}
if rootsAdded {
return roots
}
}
return nil
}

View file

@ -0,0 +1,9 @@
//go:build !linux
package aghtls
import "crypto/x509"
func rootCAs() (roots *x509.CertPool) {
return nil
}

View file

@ -147,7 +147,7 @@ func setupContext(opts options) {
// Go on. // Go on.
} }
Context.tlsRoots = LoadSystemRootCAs() Context.tlsRoots = aghtls.SystemRootCAs()
Context.transport = &http.Transport{ Context.transport = &http.Transport{
DialContext: customDialContext, DialContext: customDialContext,
Proxy: getHTTPProxy, Proxy: getHTTPProxy,

View file

@ -14,8 +14,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"path/filepath"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -699,46 +697,3 @@ func (t *TLSMod) registerWebHandlers() {
httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure) httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate) httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate)
} }
// LoadSystemRootCAs tries to load root certificates from the operating system.
// It returns nil in case nothing is found so that that Go.crypto will use it's
// default algorithm to find system root CA list.
//
// See https://github.com/AdguardTeam/AdGuardHome/internal/issues/1311.
func LoadSystemRootCAs() (roots *x509.CertPool) {
// TODO(e.burkov): Use build tags instead.
if runtime.GOOS != "linux" {
return nil
}
// Directories with the system root certificates, which aren't supported
// by Go.crypto.
dirs := []string{
// Entware.
"/opt/etc/ssl/certs",
}
roots = x509.NewCertPool()
for _, dir := range dirs {
dirEnts, err := os.ReadDir(dir)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
log.Error("opening directory: %q: %s", dir, err)
}
var rootsAdded bool
for _, de := range dirEnts {
var certData []byte
certData, err = os.ReadFile(filepath.Join(dir, de.Name()))
if err == nil && roots.AppendCertsFromPEM(certData) {
rootsAdded = true
}
}
if rootsAdded {
return roots
}
}
return nil
}