Pull request 2365: AGDNS-2714-tls-manager-tests

Merge in DNS/adguard-home from AGDNS-2714-tls-manager-tests to master

Squashed commit of the following:

commit 2a3c6558a4098eb6b531e792884e5ca2bc2dd362
Merge: 85d72559c 1a3853d52
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 17 18:07:49 2025 +0300

    Merge branch 'master' into AGDNS-2714-tls-manager-tests

commit 85d72559c371d4f14b40077d9aec69afa8dc7e73
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 17 17:55:41 2025 +0300

    home: imp tests

commit 9ad19e3cee255b157992e4045f4e27fa5aa54325
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 17 16:21:47 2025 +0300

    home: imp code

commit 8a05bc0199
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 17 15:08:58 2025 +0300

    home: imp tests

commit 85173f986d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Mar 13 18:18:56 2025 +0300

    home: add tests

commit add531ea17
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 11 19:55:51 2025 +0300

    home: tls manager tests
This commit is contained in:
Stanislav Chzhen 2025-03-17 18:16:33 +03:00
parent 1a3853d52a
commit f82dee17f0
3 changed files with 462 additions and 27 deletions

View file

@ -443,31 +443,6 @@ func startDNSServer() error {
return nil
}
// reconfigureDNSServer updates the DNS server configuration using the provided
// TLS settings. tlsMgr must not be nil.
func reconfigureDNSServer(tlsMgr *tlsManager) (err error) {
tlsConf := &tlsConfigSettings{}
tlsMgr.WriteDiskConfig(tlsConf)
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpRegister,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = globalContext.dnsServer.Reconfigure(newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}
return nil
}
func stopDNSServer() (err error) {
if !isRunning() {
return nil

View file

@ -153,7 +153,7 @@ func (m *tlsManager) reload() {
m.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer(m)
_ = m.reconfigureDNSServer()
m.confLock.Lock()
tlsConf = m.conf
@ -165,6 +165,31 @@ func (m *tlsManager) reload() {
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
}
// reconfigureDNSServer updates the DNS server configuration using the stored
// TLS settings.
func (m *tlsManager) reconfigureDNSServer() (err error) {
tlsConf := &tlsConfigSettings{}
m.WriteDiskConfig(tlsConf)
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpRegister,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = globalContext.dnsServer.Reconfigure(newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}
return nil
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
// also set in status.WarningValidation.
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
@ -442,7 +467,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
onConfigModified()
err = reconfigureDNSServer(m)
err = m.reconfigureDNSServer()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)

View file

@ -1,11 +1,33 @@
package home
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"os"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
@ -75,3 +97,416 @@ func TestValidateCertificates(t *testing.T) {
assert.True(t, status.ValidPair)
})
}
// storeGlobals is a test helper function that saves global variables and
// restores them once the test is complete.
//
// The global variables are:
// - [configuration.dns]
// - [homeContext.clients.storage]
// - [homeContext.dnsServer]
// - [homeContext.mux]
// - [homeContext.web]
//
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
// variables. Make tests that use this helper concurrent.
func storeGlobals(tb testing.TB) {
tb.Helper()
prevConfig := config
storage := globalContext.clients.storage
dnsServer := globalContext.dnsServer
mux := globalContext.mux
web := globalContext.web
tb.Cleanup(func() {
config = prevConfig
globalContext.clients.storage = storage
globalContext.dnsServer = dnsServer
globalContext.mux = mux
globalContext.web = web
})
}
// newCertAndKey is a helper function that generates certificate and key.
func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) {
tb.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(tb, err)
certTmpl := &x509.Certificate{
SerialNumber: big.NewInt(n),
}
certDER, err = x509.CreateCertificate(rand.Reader, certTmpl, certTmpl, &key.PublicKey, key)
require.NoError(tb, err)
return certDER, key
}
// writeCertAndKey is a helper function that writes certificate and key to
// specified paths.
func writeCertAndKey(
tb testing.TB,
certDER []byte,
certPath string,
key *rsa.PrivateKey,
keyPath string,
) {
tb.Helper()
certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE, 0o600)
require.NoError(tb, err)
defer func() {
err = certFile.Close()
require.NoError(tb, err)
}()
err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
require.NoError(tb, err)
keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE, 0o600)
require.NoError(tb, err)
defer func() {
err = keyFile.Close()
require.NoError(tb, err)
}()
err = pem.Encode(keyFile, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
require.NoError(tb, err)
}
// assertCertSerialNumber is a helper function that checks serial number of the
// TLS certificate.
func assertCertSerialNumber(tb testing.TB, conf *tlsConfigSettings, wantSN int64) {
tb.Helper()
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
require.NoError(tb, err)
assert.Equal(tb, wantSN, cert.Leaf.SerialNumber.Int64())
}
func TestTLSManager_Reload(t *testing.T) {
storeGlobals(t)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
Logger: logger,
})
require.NoError(t, err)
globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
Logger: logger,
Clock: timeutil.SystemClock{},
})
require.NoError(t, err)
globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
const (
snBefore int64 = 1
snAfter int64 = 2
)
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
certDER, key := newCertAndKey(t, snBefore)
writeCertAndKey(t, certDER, certPath, key, keyPath)
m, err := newTLSManager(tlsConfigSettings{
Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificatePath: certPath,
PrivateKeyPath: keyPath,
},
}, false)
require.NoError(t, err)
conf := &tlsConfigSettings{}
m.WriteDiskConfig(conf)
assertCertSerialNumber(t, conf, snBefore)
certDER, key = newCertAndKey(t, snAfter)
writeCertAndKey(t, certDER, certPath, key, keyPath)
m.reload()
m.WriteDiskConfig(conf)
assertCertSerialNumber(t, conf, snAfter)
}
func TestTLSManager_HandleTLSStatus(t *testing.T) {
m, err := newTLSManager(tlsConfigSettings{
Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificateChain: string(testCertChainData),
PrivateKey: string(testPrivateKeyData),
},
}, false)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil)
m.handleTLSStatus(w, r)
res := &tlsConfigSettingsExt{}
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData)
assert.True(t, res.Enabled)
assert.Equal(t, wantCertificateChain, res.CertificateChain)
assert.True(t, res.PrivateKeySaved)
}
func TestValidateTLSSettings(t *testing.T) {
storeGlobals(t)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, ln.Close)
addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr())
busyPort := addr.Port
globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
testCases := []struct {
setts tlsConfigSettingsExt
name string
wantErr string
}{{
name: "basic",
setts: tlsConfigSettingsExt{},
wantErr: "",
}, {
setts: tlsConfigSettingsExt{
ServePlainDNS: aghalg.NBFalse,
},
name: "disabled_all",
wantErr: "plain DNS is required in case encryption protocols are disabled",
}, {
setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortHTTPS: uint16(busyPort),
},
},
name: "busy_port",
wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort),
}, {
setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortHTTPS: 4433,
PortDNSOverTLS: 4433,
},
},
name: "duplicate_port",
wantErr: "validating tcp ports: duplicated values: [4433]",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = validateTLSSettings(tc.setts)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func TestTLSManager_HandleTLSValidate(t *testing.T) {
storeGlobals(t)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
m, err := newTLSManager(tlsConfigSettings{
Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificateChain: string(testCertChainData),
PrivateKey: string(testPrivateKeyData),
},
}, false)
require.NoError(t, err)
setts := &tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
},
},
}
req, err := json.Marshal(setts)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req))
m.handleTLSValidate(w, r)
res := &tlsConfigStatus{}
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
require.NoError(t, err)
wantIssuer := cert.Leaf.Issuer.String()
assert.Equal(t, wantIssuer, res.Issuer)
}
func TestTLSManager_HandleTLSConfigure(t *testing.T) {
// Store the global state before making any changes.
storeGlobals(t)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
Logger: logger,
})
require.NoError(t, err)
err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{
Config: dnsforward.Config{
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
ClientsContainer: dnsforward.EmptyClientsContainer{},
},
ServePlainDNS: true,
})
require.NoError(t, err)
globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
Logger: logger,
Clock: timeutil.SystemClock{},
})
require.NoError(t, err)
globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
config.DNS.Port = 0
const wantSerialNumber int64 = 1
// Prepare the TLS manager configuration.
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
certDER, key := newCertAndKey(t, wantSerialNumber)
writeCertAndKey(t, certDER, certPath, key, keyPath)
// Initialize the TLS manager and assert its configuration.
m, err := newTLSManager(tlsConfigSettings{
Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificatePath: certPath,
PrivateKeyPath: keyPath,
},
}, true)
require.NoError(t, err)
conf := &tlsConfigSettings{}
m.WriteDiskConfig(conf)
assertCertSerialNumber(t, conf, wantSerialNumber)
// Prepare a request with the new TLS configuration.
setts := &tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortHTTPS: 4433,
TLSConfig: dnsforward.TLSConfig{
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
},
},
}
req, err := json.Marshal(setts)
require.NoError(t, err)
r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req))
w := httptest.NewRecorder()
// Reconfigure the TLS manager.
m.handleTLSConfigure(w, r)
// The [tlsManager.handleTLSConfigure] method will start the DNS server and
// it should be stopped after the test ends.
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
res := &tlsConfig{
tlsConfigStatus: &tlsConfigStatus{},
}
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
require.NoError(t, err)
wantIssuer := cert.Leaf.Issuer.String()
assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer)
// Assert that the Web API's TLS configuration has been updated.
//
// TODO(s.chzhen): Remove when [httpsServer.cond] is removed.
assert.Eventually(t, func() bool {
globalContext.web.httpsServer.condLock.Lock()
defer globalContext.web.httpsServer.condLock.Unlock()
cert = globalContext.web.httpsServer.cert
if cert.Leaf == nil {
return false
}
assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String())
return true
}, testTimeout, testTimeout/10)
}