diff --git a/internal/home/dns.go b/internal/home/dns.go index 4a89131c..9b291695 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -495,31 +495,6 @@ func startDNSServer() error { return nil } -// reconfigureDNSServer updates the DNS server configuration using the provided -// TLS settings. tlsMgr must not be nil. -func reconfigureDNSServer(tlsMgr *tlsManager) (err error) { - tlsConf := &tlsConfigSettings{} - tlsMgr.WriteDiskConfig(tlsConf) - - newConf, err := newServerConfig( - &config.DNS, - config.Clients.Sources, - tlsConf, - httpRegister, - globalContext.clients.storage, - ) - if err != nil { - return fmt.Errorf("generating forwarding dns server config: %w", err) - } - - err = globalContext.dnsServer.Reconfigure(newConf) - if err != nil { - return fmt.Errorf("starting forwarding dns server: %w", err) - } - - return nil -} - func stopDNSServer() (err error) { if !isRunning() { return nil diff --git a/internal/home/tls.go b/internal/home/tls.go index 611a7383..2882e8ef 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -153,13 +153,7 @@ func (m *tlsManager) reload() { m.certLastMod = fi.ModTime().UTC() - // TODO(s.chzhen): Temporary check for tests. Remove this after - // refactoring. - if globalContext.dnsServer == nil && globalContext.web == nil { - return - } - - _ = reconfigureDNSServer(m) + _ = m.reconfigureDNSServer() m.confLock.Lock() tlsConf = m.conf @@ -171,6 +165,31 @@ func (m *tlsManager) reload() { globalContext.web.tlsConfigChanged(context.Background(), tlsConf) } +// reconfigureDNSServer updates the DNS server configuration using the stored +// TLS settings. +func (m *tlsManager) reconfigureDNSServer() (err error) { + tlsConf := &tlsConfigSettings{} + m.WriteDiskConfig(tlsConf) + + newConf, err := newServerConfig( + &config.DNS, + config.Clients.Sources, + tlsConf, + httpRegister, + globalContext.clients.storage, + ) + if err != nil { + return fmt.Errorf("generating forwarding dns server config: %w", err) + } + + err = globalContext.dnsServer.Reconfigure(newConf) + if err != nil { + return fmt.Errorf("starting forwarding dns server: %w", err) + } + + return nil +} + // loadTLSConf loads and validates the TLS configuration. The returned error is // also set in status.WarningValidation. func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { @@ -448,7 +467,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) onConfigModified() - err = reconfigureDNSServer(m) + err = m.reconfigureDNSServer() if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index e7de146d..75544aca 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -1,19 +1,29 @@ package home import ( + "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" "math/big" + "net/http" + "net/http/httptest" + "net/netip" "os" "path/filepath" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -86,6 +96,36 @@ func TestValidateCertificates(t *testing.T) { }) } +// storeGlobals is a test helper function that saves global variables and +// restores them once the test is complete. +// +// The global variables are: +// - [configuration.dns] +// - [homeContext.clients.storage] +// - [homeContext.dnsServer] +// - [homeContext.mux] +// - [homeContext.web] +// +// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global +// variables. Make tests that use this helper concurrent. +func storeGlobals(tb testing.TB) { + tb.Helper() + + prevConfig := config + storage := globalContext.clients.storage + dnsServer := globalContext.dnsServer + mux := globalContext.mux + web := globalContext.web + + tb.Cleanup(func() { + config = prevConfig + globalContext.clients.storage = storage + globalContext.dnsServer = dnsServer + globalContext.mux = mux + globalContext.web = web + }) +} + // newCertAndKey is a helper function that generates certificate and key. func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) { tb.Helper() @@ -152,6 +192,30 @@ func assertCertSerialNumber(tb testing.TB, conf *tlsConfigSettings, wantSN int64 } func TestTLSManager_Reload(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: logger, + }) + require.NoError(t, err) + + globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + Logger: logger, + Clock: timeutil.SystemClock{}, + }) + require.NoError(t, err) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + const ( snBefore int64 = 1 snAfter int64 = 2 @@ -185,3 +249,254 @@ func TestTLSManager_Reload(t *testing.T) { m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, snAfter) } + +func TestTLSManager_HandleTLSStatus(t *testing.T) { + m, err := newTLSManager(tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, + }, false) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil) + m.handleTLSStatus(w, r) + + res := &tlsConfigSettingsExt{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData) + assert.True(t, res.Enabled) + assert.Equal(t, wantCertificateChain, res.CertificateChain) + assert.True(t, res.PrivateKeySaved) +} + +func TestValidateTLSSettings(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + testCases := []struct { + setts tlsConfigSettingsExt + name string + wantErr string + }{{ + name: "basic", + setts: tlsConfigSettingsExt{}, + wantErr: "", + }, { + setts: tlsConfigSettingsExt{ + ServePlainDNS: aghalg.NBFalse, + }, + name: "disabled_all", + wantErr: "plain DNS is required in case encryption protocols are disabled", + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 433, + }, + }, + name: "privileged_port", + wantErr: "port 433 is not available, cannot enable HTTPS on it", + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + PortDNSOverTLS: 4433, + }, + }, + name: "duplicate_port", + wantErr: "validating tcp ports: duplicated values: [4433]", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = validateTLSSettings(tc.setts) + testutil.AssertErrorMsg(t, tc.wantErr, err) + }) + } +} + +func TestTLSManager_HandleTLSValidate(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + m, err := newTLSManager(tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, + }, false) + require.NoError(t, err) + + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), + PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + }, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req)) + m.handleTLSValidate(w, r) + + res := &tlsConfigStatus{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.Issuer) +} + +func TestTLSManager_HandleTLSConfigure(t *testing.T) { + // Store the global state before making any changes. + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: logger, + }) + require.NoError(t, err) + + err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{ + Config: dnsforward.Config{ + UpstreamMode: dnsforward.UpstreamModeLoadBalance, + EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, + ClientsContainer: dnsforward.EmptyClientsContainer{}, + }, + ServePlainDNS: true, + }) + require.NoError(t, err) + + globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + Logger: logger, + Clock: timeutil.SystemClock{}, + }) + require.NoError(t, err) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} + config.DNS.Port = 0 + + const wantSerialNumber int64 = 1 + + // Prepare the TLS manager configuration. + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + certDER, key := newCertAndKey(t, wantSerialNumber) + writeCertAndKey(t, certDER, certPath, key, keyPath) + + // Initialize the TLS manager and assert its configuration. + m, err := newTLSManager(tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificatePath: certPath, + PrivateKeyPath: keyPath, + }, + }, true) + require.NoError(t, err) + + conf := &tlsConfigSettings{} + m.WriteDiskConfig(conf) + assertCertSerialNumber(t, conf, wantSerialNumber) + + // Prepare a request with the new TLS configuration. + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), + PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + }, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req)) + w := httptest.NewRecorder() + + // Reconfigure the TLS manager. + m.handleTLSConfigure(w, r) + + // The [tlsManager.handleTLSConfigure] method will start the DNS server and + // it should be stopped after the test ends. + t.Cleanup(func() { + err = globalContext.dnsServer.Stop() + require.NoError(t, err) + }) + + res := &tlsConfig{ + tlsConfigStatus: &tlsConfigStatus{}, + } + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer) + + // Assert that the Web API's TLS configuration has been updated. + assert.Eventually(t, func() bool { + globalContext.web.httpsServer.condLock.Lock() + defer globalContext.web.httpsServer.condLock.Unlock() + + cert = globalContext.web.httpsServer.cert + if cert.Leaf == nil { + return false + } + + assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String()) + + return true + }, testTimeout, testTimeout/10) +}