package home import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "math/big" "net" "net/http" "net/http/httptest" "net/netip" "os" "path/filepath" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var testCertChainData = []byte(`-----BEGIN CERTIFICATE----- MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6 i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4 eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8 LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== -----END CERTIFICATE-----`) var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY----- MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72 9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU kXS9jgARhhiWXJrk -----END PRIVATE KEY-----`) func TestValidateCertificates(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) logger := slogutil.NewDiscardLogger() m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: logger, configModified: func() {}, servePlainDNS: false, }) require.NoError(t, err) t.Run("bad_certificate", func(t *testing.T) { status := &tlsConfigStatus{} err = m.validateCertificates(ctx, status, []byte("bad cert"), nil, "") testutil.AssertErrorMsg(t, "empty certificate", err) assert.False(t, status.ValidCert) assert.False(t, status.ValidChain) }) t.Run("bad_private_key", func(t *testing.T) { status := &tlsConfigStatus{} err = m.validateCertificates(ctx, status, nil, []byte("bad priv key"), "") testutil.AssertErrorMsg(t, "no valid keys were found", err) assert.False(t, status.ValidKey) }) t.Run("valid", func(t *testing.T) { status := &tlsConfigStatus{} err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "") assert.Error(t, err) notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC) notAfter := time.Date(2046, 7, 14, 9, 24, 23, 0, time.UTC) assert.True(t, status.ValidCert) assert.False(t, status.ValidChain) assert.True(t, status.ValidKey) assert.Equal(t, "RSA", status.KeyType) assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Subject) assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Issuer) assert.Equal(t, notBefore, status.NotBefore) assert.Equal(t, notAfter, status.NotAfter) assert.True(t, status.ValidPair) }) } // storeGlobals is a test helper function that saves global variables and // restores them once the test is complete. // // The global variables are: // - [configuration.dns] // - [homeContext.clients.storage] // - [homeContext.dnsServer] // - [homeContext.mux] // - [homeContext.web] // // TODO(s.chzhen): Remove this once the TLS manager no longer accesses global // variables. Make tests that use this helper concurrent. func storeGlobals(tb testing.TB) { tb.Helper() prevConfig := config storage := globalContext.clients.storage dnsServer := globalContext.dnsServer mux := globalContext.mux web := globalContext.web tb.Cleanup(func() { config = prevConfig globalContext.clients.storage = storage globalContext.dnsServer = dnsServer globalContext.mux = mux globalContext.web = web }) } // newCertAndKey is a helper function that generates certificate and key. func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) { tb.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(tb, err) certTmpl := &x509.Certificate{ SerialNumber: big.NewInt(n), } certDER, err = x509.CreateCertificate(rand.Reader, certTmpl, certTmpl, &key.PublicKey, key) require.NoError(tb, err) return certDER, key } // writeCertAndKey is a helper function that writes certificate and key to // specified paths. func writeCertAndKey( tb testing.TB, certDER []byte, certPath string, key *rsa.PrivateKey, keyPath string, ) { tb.Helper() certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE, 0o600) require.NoError(tb, err) defer func() { err = certFile.Close() require.NoError(tb, err) }() err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) require.NoError(tb, err) keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE, 0o600) require.NoError(tb, err) defer func() { err = keyFile.Close() require.NoError(tb, err) }() err = pem.Encode(keyFile, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), }) require.NoError(tb, err) } // assertCertSerialNumber is a helper function that checks serial number of the // TLS certificate. func assertCertSerialNumber(tb testing.TB, conf *tlsConfigSettings, wantSN int64) { tb.Helper() cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData) require.NoError(tb, err) assert.Equal(tb, wantSN, cert.Leaf.SerialNumber.Int64()) } func TestTLSManager_Reload(t *testing.T) { storeGlobals(t) var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ Logger: logger, }) require.NoError(t, err) globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ Logger: logger, Clock: timeutil.SystemClock{}, }) require.NoError(t, err) globalContext.mux = http.NewServeMux() globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) const ( snBefore int64 = 1 snAfter int64 = 2 ) tmpDir := t.TempDir() certPath := filepath.Join(tmpDir, "cert.pem") keyPath := filepath.Join(tmpDir, "key.pem") certDER, key := newCertAndKey(t, snBefore) writeCertAndKey(t, certDER, certPath, key, keyPath) m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: logger, configModified: func() {}, tlsSettings: tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificatePath: certPath, PrivateKeyPath: keyPath, }, }, servePlainDNS: false, }) require.NoError(t, err) conf := &tlsConfigSettings{} m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, snBefore) certDER, key = newCertAndKey(t, snAfter) writeCertAndKey(t, certDER, certPath, key, keyPath) m.setWebAPI(globalContext.web) m.reload(ctx) m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, snAfter) } func TestTLSManager_HandleTLSStatus(t *testing.T) { var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: logger, configModified: func() {}, tlsSettings: tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificateChain: string(testCertChainData), PrivateKey: string(testPrivateKeyData), }, }, servePlainDNS: false, }) require.NoError(t, err) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil) m.handleTLSStatus(w, r) res := &tlsConfigSettingsExt{} err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData) assert.True(t, res.Enabled) assert.Equal(t, wantCertificateChain, res.CertificateChain) assert.True(t, res.PrivateKeySaved) } func TestValidateTLSSettings(t *testing.T) { storeGlobals(t) var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) ln, err := net.Listen("tcp", ":0") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, ln.Close) addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr()) busyPort := addr.Port globalContext.mux = http.NewServeMux() globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) testCases := []struct { setts tlsConfigSettingsExt name string wantErr string }{{ name: "basic", setts: tlsConfigSettingsExt{}, wantErr: "", }, { setts: tlsConfigSettingsExt{ ServePlainDNS: aghalg.NBFalse, }, name: "disabled_all", wantErr: "plain DNS is required in case encryption protocols are disabled", }, { setts: tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, PortHTTPS: uint16(busyPort), }, }, name: "busy_port", wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort), }, { setts: tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, PortHTTPS: 4433, PortDNSOverTLS: 4433, }, }, name: "duplicate_port", wantErr: "validating tcp ports: duplicated values: [4433]", }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err = validateTLSSettings(tc.setts) testutil.AssertErrorMsg(t, tc.wantErr, err) }) } } func TestTLSManager_HandleTLSValidate(t *testing.T) { storeGlobals(t) var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) globalContext.mux = http.NewServeMux() globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: logger, configModified: func() {}, tlsSettings: tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificateChain: string(testCertChainData), PrivateKey: string(testPrivateKeyData), }, }, servePlainDNS: false, }) require.NoError(t, err) setts := &tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), }, }, } req, err := json.Marshal(setts) require.NoError(t, err) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req)) m.handleTLSValidate(w, r) res := &tlsConfigStatus{} err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) require.NoError(t, err) wantIssuer := cert.Leaf.Issuer.String() assert.Equal(t, wantIssuer, res.Issuer) } func TestTLSManager_HandleTLSConfigure(t *testing.T) { // Store the global state before making any changes. storeGlobals(t) var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ Logger: logger, }) require.NoError(t, err) err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{ Config: dnsforward.Config{ UpstreamMode: dnsforward.UpstreamModeLoadBalance, EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, ClientsContainer: dnsforward.EmptyClientsContainer{}, }, ServePlainDNS: true, }) require.NoError(t, err) globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ Logger: logger, Clock: timeutil.SystemClock{}, }) require.NoError(t, err) globalContext.mux = http.NewServeMux() globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} config.DNS.Port = 0 const wantSerialNumber int64 = 1 // Prepare the TLS manager configuration. tmpDir := t.TempDir() certPath := filepath.Join(tmpDir, "cert.pem") keyPath := filepath.Join(tmpDir, "key.pem") certDER, key := newCertAndKey(t, wantSerialNumber) writeCertAndKey(t, certDER, certPath, key, keyPath) // Initialize the TLS manager and assert its configuration. m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: logger, configModified: func() {}, tlsSettings: tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificatePath: certPath, PrivateKeyPath: keyPath, }, }, servePlainDNS: true, }) require.NoError(t, err) conf := &tlsConfigSettings{} m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, wantSerialNumber) // Prepare a request with the new TLS configuration. setts := &tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, PortHTTPS: 4433, TLSConfig: dnsforward.TLSConfig{ CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), }, }, } req, err := json.Marshal(setts) require.NoError(t, err) r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req)) w := httptest.NewRecorder() // Reconfigure the TLS manager. m.setWebAPI(globalContext.web) m.handleTLSConfigure(w, r) // The [tlsManager.handleTLSConfigure] method will start the DNS server and // it should be stopped after the test ends. testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop) res := &tlsConfig{ tlsConfigStatus: &tlsConfigStatus{}, } err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) require.NoError(t, err) wantIssuer := cert.Leaf.Issuer.String() assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer) // Assert that the Web API's TLS configuration has been updated. // // TODO(s.chzhen): Remove when [httpsServer.cond] is removed. assert.Eventually(t, func() bool { globalContext.web.httpsServer.condLock.Lock() defer globalContext.web.httpsServer.condLock.Unlock() cert = globalContext.web.httpsServer.cert if cert.Leaf == nil { return false } assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String()) return true }, testTimeout, testTimeout/10) }