Pull request 2371: AGDNS-2714-tls-manager

Merge in DNS/adguard-home from AGDNS-2714-tls-manager to master

Squashed commit of the following:

commit 5c7cd1fa6d8a9bc1fd0f891818589b48bee641dc
Merge: 381f7666b 810ae9483
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Mar 26 14:13:49 2025 +0300

    Merge branch 'master' into AGDNS-2714-tls-manager

commit 381f7666b0
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 25 19:53:12 2025 +0300

    home: imp code

commit 20be72abd4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 25 19:19:51 2025 +0300

    home: imp code

commit b5a06e6a15
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 24 21:45:41 2025 +0300

    home: imp code

commit a6a5ba727e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Mar 20 21:06:34 2025 +0300

    home: imp docs

commit 71d379bafc
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Mar 20 20:47:15 2025 +0300

    all: upd chlog

commit be69a5b85d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Mar 19 20:14:20 2025 +0300

    home: imp docs

commit 85b28db73b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Mar 19 20:07:59 2025 +0300

    home: imp code

commit c11e4c9e50
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Mar 19 19:11:59 2025 +0300

    home: imp code

commit 60eff2c663
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 21:27:49 2025 +0300

    home: imp code

commit fa9d57b283
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 21:14:56 2025 +0300

    home: imp docs

commit 3f561b6475
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 20:59:59 2025 +0300

    home: imp code

commit 927296c49f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 18:19:22 2025 +0300

    home: imp naming

commit e35f742e42
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 17:53:17 2025 +0300

    home: tls manager web api

commit 85a4de7931
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Mar 18 15:06:34 2025 +0300

    home: tls manager config

commit 515b26d6bd
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Mar 17 22:15:25 2025 +0300

    home: tls manager
This commit is contained in:
Stanislav Chzhen 2025-03-26 14:26:57 +03:00
parent 810ae94832
commit 8b4768aadd
8 changed files with 420 additions and 208 deletions

View file

@ -53,6 +53,8 @@ See also the [v0.107.58 GitHub milestone][ms-v0.107.58].
### Fixed ### Fixed
- Validation process for the HTTPS port on the *Encryption Settings* page.
- Clearing the DNS cache on the *DNS settings* page now includes both global cache and custom client cache. - Clearing the DNS cache on the *DNS settings* page now includes both global cache and custom client cache.
- Invalid ICMPv6 Router Advertisement messages ([#7547]). - Invalid ICMPv6 Router Advertisement messages ([#7547]).

View file

@ -568,7 +568,7 @@ func parseConfig() (err error) {
} }
// Do not wrap the error because it's informative enough as is. // Do not wrap the error because it's informative enough as is.
return setContextTLSCipherIDs() return validateTLSCipherIDs(config.TLS.OverrideTLSCiphers)
} }
// validateConfig returns error if the configuration is invalid. // validateConfig returns error if the configuration is invalid.
@ -721,21 +721,15 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) {
return nil return nil
} }
// setContextTLSCipherIDs sets the TLS cipher suite IDs to use. // validateTLSCipherIDs validates the custom TLS cipher suite IDs.
func setContextTLSCipherIDs() (err error) { func validateTLSCipherIDs(cipherIDs []string) (err error) {
if len(config.TLS.OverrideTLSCiphers) == 0 { if len(cipherIDs) == 0 {
log.Info("tls: using default ciphers")
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
return nil return nil
} }
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers) _, err = aghtls.ParseCiphers(cipherIDs)
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
if err != nil { if err != nil {
return fmt.Errorf("parsing override ciphers: %w", err) return fmt.Errorf("override_tls_ciphers: %w", err)
} }
return nil return nil

View file

@ -38,6 +38,8 @@ const (
) )
// Called by other modules when configuration is changed // Called by other modules when configuration is changed
//
// TODO(s.chzhen): Remove this after refactoring.
func onConfigModified() { func onConfigModified() {
err := config.write(globalContext.tls) err := config.write(globalContext.tls)
if err != nil { if err != nil {
@ -120,14 +122,15 @@ func initDNS(
anonymizer, anonymizer,
httpRegister, httpRegister,
tlsConf, tlsConf,
tlsMgr,
baseLogger, baseLogger,
) )
} }
// initDNSServer initializes the [context.dnsServer]. To only use the internal // initDNSServer initializes the [context.dnsServer]. To only use the internal
// proxy, none of the arguments are required, but tlsConf and l still must not // proxy, none of the arguments are required, but tlsConf, tlsMgr and l still
// be nil, in other cases all the arguments also must not be nil. It also must // must not be nil, in other cases all the arguments also must not be nil. It
// not be called unless [config] and [globalContext] are initialized. // also must not be called unless [config] and [globalContext] are initialized.
// //
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter. // TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
func initDNSServer( func initDNSServer(
@ -138,6 +141,7 @@ func initDNSServer(
anonymizer *aghnet.IPMut, anonymizer *aghnet.IPMut,
httpReg aghhttp.RegisterFunc, httpReg aghhttp.RegisterFunc,
tlsConf *tlsConfigSettings, tlsConf *tlsConfigSettings,
tlsMgr *tlsManager,
l *slog.Logger, l *slog.Logger,
) (err error) { ) (err error) {
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
@ -166,6 +170,7 @@ func initDNSServer(
&config.DNS, &config.DNS,
config.Clients.Sources, config.Clients.Sources,
tlsConf, tlsConf,
tlsMgr,
httpReg, httpReg,
globalContext.clients.storage, globalContext.clients.storage,
) )
@ -236,11 +241,12 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
} }
// newServerConfig converts values from the configuration file into the internal // newServerConfig converts values from the configuration file into the internal
// DNS server configuration. All arguments must not be nil. // DNS server configuration. All arguments must not be nil, except for httpReg.
func newServerConfig( func newServerConfig(
dnsConf *dnsConfig, dnsConf *dnsConfig,
clientSrcConf *clientSourcesConfig, clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings, tlsConf *tlsConfigSettings,
tlsMgr *tlsManager,
httpReg aghhttp.RegisterFunc, httpReg aghhttp.RegisterFunc,
clientsContainer dnsforward.ClientsContainer, clientsContainer dnsforward.ClientsContainer,
) (newConf *dnsforward.ServerConfig, err error) { ) (newConf *dnsforward.ServerConfig, err error) {
@ -256,7 +262,7 @@ func newServerConfig(
TLSConfig: newDNSTLSConfig(tlsConf, hosts), TLSConfig: newDNSTLSConfig(tlsConf, hosts),
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH, TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout), UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
TLSv12Roots: globalContext.tlsRoots, TLSv12Roots: tlsMgr.rootCerts,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpReg, HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,

View file

@ -3,7 +3,6 @@ package home
import ( import (
"context" "context"
"crypto/x509"
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog" "log/slog"
@ -22,7 +21,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@ -81,10 +79,6 @@ type homeContext struct {
workDir string // Location of our directory, used to protect against CWD being somewhere else workDir string // Location of our directory, used to protect against CWD being somewhere else
pidFileName string // PID file name. Empty if no PID file was created. pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16
// firstRun, if true, tells AdGuard Home to only start the web interface // firstRun, if true, tells AdGuard Home to only start the web interface
// service, and only serve the first-run APIs. // service, and only serve the first-run APIs.
@ -142,7 +136,6 @@ func Main(clientBuildFS fs.FS) {
func setupContext(opts options) (err error) { func setupContext(opts options) (err error) {
globalContext.firstRun = detectFirstRun() globalContext.firstRun = detectFirstRun()
globalContext.tlsRoots = aghtls.SystemRootCAs()
globalContext.mux = http.NewServeMux() globalContext.mux = http.NewServeMux()
if !opts.noEtcHosts { if !opts.noEtcHosts {
@ -274,18 +267,13 @@ func setupOpts(opts options) (err error) {
return nil return nil
} }
// initContextClients initializes Context clients and related fields. // initContextClients initializes Context clients and related fields. All
// arguments must not be nil.
func initContextClients( func initContextClients(
ctx context.Context, ctx context.Context,
logger *slog.Logger, logger *slog.Logger,
sigHdlr *signalHandler, sigHdlr *signalHandler,
) (err error) { ) (err error) {
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
//lint:ignore SA1019 Migration is not over. //lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = globalContext.workDir config.DHCP.WorkDir = globalContext.workDir
config.DHCP.DataDir = globalContext.getDataDir() config.DHCP.DataDir = globalContext.getDataDir()
@ -358,11 +346,13 @@ func setupBindOpts(opts options) (err error) {
return nil return nil
} }
// setupDNSFilteringConf sets up DNS filtering configuration settings. // setupDNSFilteringConf sets up DNS filtering configuration settings. All
// arguments must not be nil.
func setupDNSFilteringConf( func setupDNSFilteringConf(
ctx context.Context, ctx context.Context,
baseLogger *slog.Logger, baseLogger *slog.Logger,
conf *filtering.Config, conf *filtering.Config,
tlsMgr *tlsManager,
) (err error) { ) (err error) {
const ( const (
dnsTimeout = 3 * time.Second dnsTimeout = 3 * time.Second
@ -388,7 +378,7 @@ func setupDNSFilteringConf(
conf.Filters = slices.Clone(config.Filters) conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters) conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules) conf.UserRules = slices.Clone(config.UserRules)
conf.HTTPClient = httpClient() conf.HTTPClient = httpClient(tlsMgr)
cacheTime := time.Duration(conf.CacheTime) * time.Minute cacheTime := time.Duration(conf.CacheTime) * time.Minute
@ -630,6 +620,23 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
err = initContextClients(ctx, slogLogger, sigHdlr) err = initContextClients(ctx, slogLogger, sigHdlr)
fatalOnError(err) fatalOnError(err)
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
logger: tlsMgrLogger,
configModified: onConfigModified,
tlsSettings: config.TLS,
servePlainDNS: config.DNS.ServePlainDNS,
})
if err != nil {
tlsMgrLogger.ErrorContext(ctx, "initializing", slogutil.KeyError, err)
onConfigModified()
}
globalContext.tls = tlsMgr
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr)
fatalOnError(err)
err = setupOpts(opts) err = setupOpts(opts)
fatalOnError(err) fatalOnError(err)
@ -642,7 +649,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
// TODO(e.burkov): This could be made earlier, probably as the option's // TODO(e.burkov): This could be made earlier, probably as the option's
// effect. // effect.
cmdlineUpdate(ctx, slogLogger, opts, upd) cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr)
if !globalContext.firstRun { if !globalContext.firstRun {
// Save the updated config. // Save the updated config.
@ -664,19 +671,14 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
globalContext.auth, err = initUsers() globalContext.auth, err = initUsers()
fatalOnError(err) fatalOnError(err)
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
tlsMgr, err := newTLSManager(ctx, tlsMgrLogger, config.TLS, config.DNS.ServePlainDNS)
if err != nil {
log.Error("initializing tls: %s", err)
onConfigModified()
}
globalContext.tls = tlsMgr
sigHdlr.addTLSManager(tlsMgr)
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
fatalOnError(err) fatalOnError(err)
globalContext.web = web
tlsMgr.setWebAPI(web)
sigHdlr.addTLSManager(tlsMgr)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config) statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
fatalOnError(err) fatalOnError(err)
@ -706,7 +708,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir) checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
} }
globalContext.web.start(ctx) web.start(ctx)
// Wait for other goroutines to complete their job. // Wait for other goroutines to complete their job.
<-done <-done
@ -1058,8 +1060,15 @@ type jsonError struct {
Message string `json:"message"` Message string `json:"message"`
} }
// cmdlineUpdate updates current application and exits. l must not be nil. // cmdlineUpdate updates current application and exits. l and tlsMgr must not
func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updater.Updater) { // be nil.
func cmdlineUpdate(
ctx context.Context,
l *slog.Logger,
opts options,
upd *updater.Updater,
tlsMgr *tlsManager,
) {
if !opts.performUpdate { if !opts.performUpdate {
return return
} }
@ -1069,7 +1078,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
// //
// TODO(e.burkov): We could probably initialize the internal resolver // TODO(e.burkov): We could probably initialize the internal resolver
// separately. // separately.
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l) err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, tlsMgr, l)
fatalOnError(err) fatalOnError(err)
l.InfoContext(ctx, "performing update via cli") l.InfoContext(ctx, "performing update via cli")

View file

@ -10,10 +10,10 @@ import (
// httpClient returns a new HTTP client that uses the AdGuard Home's own DNS // httpClient returns a new HTTP client that uses the AdGuard Home's own DNS
// server for resolving hostnames. The resulting client should not be used // server for resolving hostnames. The resulting client should not be used
// until [Context.dnsServer] is initialized. // until [Context.dnsServer] is initialized. tlsMgr must not be nil.
// //
// TODO(a.garipov, e.burkov): This is rather messy. Refactor. // TODO(a.garipov, e.burkov): This is rather messy. Refactor.
func httpClient() (c *http.Client) { func httpClient(tlsMgr *tlsManager) (c *http.Client) {
// Do not use Context.dnsServer.DialContext directly in the struct literal // Do not use Context.dnsServer.DialContext directly in the struct literal
// below, since Context.dnsServer may be nil when this function is called. // below, since Context.dnsServer may be nil when this function is called.
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) { dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
DialContext: dialContext, DialContext: dialContext,
Proxy: httpProxy, Proxy: httpProxy,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
RootCAs: globalContext.tlsRoots, RootCAs: tlsMgr.rootCerts,
CipherSuites: globalContext.tlsCipherIDs, CipherSuites: tlsMgr.customCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
}, },

View file

@ -14,6 +14,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"net/netip"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -21,6 +22,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -41,6 +43,22 @@ type tlsManager struct {
// certLastMod is the last modification time of the certificate file. // certLastMod is the last modification time of the certificate file.
certLastMod time.Time certLastMod time.Time
// rootCerts is a pool of root CAs for TLSv1.2.
rootCerts *x509.CertPool
// web is the web UI and API server. It must not be nil.
//
// TODO(s.chzhen): Temporary cyclic dependency due to ongoing refactoring.
// Resolve it.
web *webAPI
// configModified is called when the TLS configuration is changed via an
// HTTP request.
configModified func()
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
customCipherIDs []uint16
confLock sync.Mutex confLock sync.Mutex
conf tlsConfigSettings conf tlsConfigSettings
@ -48,21 +66,50 @@ type tlsManager struct {
servePlainDNS bool servePlainDNS bool
} }
// tlsManagerConfig contains the settings for initializing the TLS manager.
type tlsManagerConfig struct {
// logger is used for logging the operation of the TLS Manager. It must not
// be nil.
logger *slog.Logger
// configModified is called when the TLS configuration is changed via an
// HTTP request. It must not be nil.
configModified func()
// tlsSettings contains the TLS configuration settings.
tlsSettings tlsConfigSettings
// servePlainDNS defines if plain DNS is allowed for incoming requests.
servePlainDNS bool
}
// newTLSManager initializes the manager of TLS configuration. m is always // newTLSManager initializes the manager of TLS configuration. m is always
// non-nil while any returned error indicates that the TLS configuration isn't // non-nil while any returned error indicates that the TLS configuration isn't
// valid. Thus TLS may be initialized later, e.g. via the web UI. logger must // valid. Thus TLS may be initialized later, e.g. via the web UI. conf must
// not be nil. // not be nil. Note that [tlsManager.web] must be initialized later on by using
func newTLSManager( // [tlsManager.setWebAPI].
ctx context.Context, func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) {
logger *slog.Logger,
conf tlsConfigSettings,
servePlainDNS bool,
) (m *tlsManager, err error) {
m = &tlsManager{ m = &tlsManager{
logger: logger, logger: conf.logger,
status: &tlsConfigStatus{}, configModified: conf.configModified,
conf: conf, status: &tlsConfigStatus{},
servePlainDNS: servePlainDNS, conf: conf.tlsSettings,
servePlainDNS: conf.servePlainDNS,
}
m.rootCerts = aghtls.SystemRootCAs()
if len(conf.tlsSettings.OverrideTLSCiphers) > 0 {
m.customCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
if err != nil {
// Should not happen because upstreams are already validated. See
// [validateTLSCipherIDs].
panic(err)
}
m.logger.InfoContext(ctx, "overriding ciphers", "ciphers", config.TLS.OverrideTLSCiphers)
} else {
m.logger.InfoContext(ctx, "using default ciphers")
} }
if m.conf.Enabled { if m.conf.Enabled {
@ -79,6 +126,15 @@ func newTLSManager(
return m, nil return m, nil
} }
// setWebAPI stores the provided web API. It must be called before
// [tlsManager.start], [tlsManager.reload], [tlsManager.handleTLSConfigure], or
// [tlsManager.validateTLSSettings].
//
// TODO(s.chzhen): Remove it once cyclic dependency is resolved.
func (m *tlsManager) setWebAPI(webAPI *webAPI) {
m.web = webAPI
}
// load reloads the TLS configuration from files or data from the config file. // load reloads the TLS configuration from files or data from the config file.
func (m *tlsManager) load(ctx context.Context) (err error) { func (m *tlsManager) load(ctx context.Context) (err error) {
err = m.loadTLSConf(ctx, &m.conf, m.status) err = m.loadTLSConf(ctx, &m.conf, m.status)
@ -126,7 +182,7 @@ func (m *tlsManager) start(_ context.Context) {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
globalContext.web.tlsConfigChanged(context.Background(), tlsConf) m.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// reload updates the configuration and restarts the TLS manager. // reload updates the configuration and restarts the TLS manager.
@ -178,7 +234,7 @@ func (m *tlsManager) reload(ctx context.Context) {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
globalContext.web.tlsConfigChanged(context.Background(), tlsConf) m.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// reconfigureDNSServer updates the DNS server configuration using the stored // reconfigureDNSServer updates the DNS server configuration using the stored
@ -191,6 +247,7 @@ func (m *tlsManager) reconfigureDNSServer() (err error) {
&config.DNS, &config.DNS,
config.Clients.Sources, config.Clients.Sources,
tlsConf, tlsConf,
m,
httpRegister, httpRegister,
globalContext.clients.storage, globalContext.clients.storage,
) )
@ -368,6 +425,8 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API. // handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
setts, err := unmarshalTLS(r) setts, err := unmarshalTLS(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
@ -379,7 +438,9 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts.PrivateKey = m.conf.PrivateKey setts.PrivateKey = m.conf.PrivateKey
} }
if err = validateTLSSettings(setts); err != nil { if err = m.validateTLSSettings(setts); err != nil {
m.logger.InfoContext(ctx, "validating tls settings", slogutil.KeyError, err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
@ -388,7 +449,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
// Skip the error check, since we are only interested in the value of // Skip the error check, since we are only interested in the value of
// status.WarningValidation. // status.WarningValidation.
status := &tlsConfigStatus{} status := &tlsConfigStatus{}
_ = m.loadTLSConf(r.Context(), &setts.tlsConfigSettings, status) _ = m.loadTLSConf(ctx, &setts.tlsConfigSettings, status)
resp := tlsConfig{ resp := tlsConfig{
tlsConfigSettingsExt: setts, tlsConfigSettingsExt: setts,
tlsConfigStatus: status, tlsConfigStatus: status,
@ -458,7 +519,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
req.PrivateKey = m.conf.PrivateKey req.PrivateKey = m.conf.PrivateKey
} }
if err = validateTLSSettings(req); err != nil { if err = m.validateTLSSettings(req); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
@ -489,7 +550,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
}() }()
} }
onConfigModified() m.configModified()
err = m.reconfigureDNSServer() err = m.reconfigureDNSServer()
if err != nil { if err != nil {
@ -516,36 +577,54 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason. // same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings) m.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}() }()
} }
} }
// validateTLSSettings returns error if the setts are not valid. // validateTLSSettings returns error if the setts are not valid.
func validateTLSSettings(setts tlsConfigSettingsExt) (err error) { func (m *tlsManager) validateTLSSettings(setts tlsConfigSettingsExt) (err error) {
if setts.Enabled { if !setts.Enabled {
err = validatePorts( if setts.ServePlainDNS == aghalg.NBFalse {
tcpPort(config.HTTPConfig.Address.Port()), // TODO(a.garipov): Support full disabling of all DNS.
tcpPort(setts.PortHTTPS), return errors.Error("plain DNS is required in case encryption protocols are disabled")
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
} }
} else if setts.ServePlainDNS == aghalg.NBFalse {
// TODO(a.garipov): Support full disabling of all DNS. return nil
return errors.Error("plain DNS is required in case encryption protocols are disabled")
} }
if !webCheckPortAvailable(setts.PortHTTPS) { var (
return fmt.Errorf("port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS) tlsConf tlsConfigSettings
webAPIAddr netip.Addr
webAPIPort uint16
plainDNSPort uint16
)
func() {
config.Lock()
defer config.Unlock()
tlsConf = config.TLS
webAPIAddr = config.HTTPConfig.Address.Addr()
webAPIPort = config.HTTPConfig.Address.Port()
plainDNSPort = config.DNS.Port
}()
err = validatePorts(
tcpPort(webAPIPort),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(plainDNSPort),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
} }
return nil // Don't wrap the error because it's informative enough as is.
return m.checkPortAvailability(tlsConf, setts.tlsConfigSettings, webAPIAddr)
} }
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home // validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
@ -557,10 +636,11 @@ func validatePorts(
tcpPorts := aghalg.UniqChecker[tcpPort]{} tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts( addPorts(
tcpPorts, tcpPorts,
tcpPort(bindPort), bindPort,
tcpPort(dohPort), dohPort,
tcpPort(dotPort), dotPort,
tcpPort(dnscryptTCPPort), dnscryptTCPPort,
tcpPort(dnsPort),
) )
err = tcpPorts.Validate() err = tcpPorts.Validate()
@ -569,7 +649,7 @@ func validatePorts(
} }
udpPorts := aghalg.UniqChecker[udpPort]{} udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort)) addPorts(udpPorts, dnsPort, doqPort)
err = udpPorts.Validate() err = udpPorts.Validate()
if err != nil { if err != nil {
@ -604,7 +684,7 @@ func (m *tlsManager) validateCertChain(
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
DNSName: srvName, DNSName: srvName,
Roots: globalContext.tlsRoots, Roots: m.rootCerts,
Intermediates: pool, Intermediates: pool,
} }
_, err = main.Verify(opts) _, err = main.Verify(opts)
@ -615,6 +695,67 @@ func (m *tlsManager) validateCertChain(
return nil return nil
} }
// checkPortAvailability checks [tlsConfigSettings.PortHTTPS],
// [tlsConfigSettings.PortDNSOverTLS], and [tlsConfigSettings.PortDNSOverQUIC]
// are available for use. It checks the current configuration and, if needed,
// attempts to bind to the port. The function returns human-readable error
// messages for the frontend. This is best-effort check to prevent an "address
// already in use" error.
//
// TODO(a.garipov): Adapt for HTTP/3.
func (m *tlsManager) checkPortAvailability(
currConf tlsConfigSettings,
newConf tlsConfigSettings,
addr netip.Addr,
) (err error) {
const (
networkTCP = "tcp"
networkUDP = "udp"
protoHTTPS = "HTTPS"
protoDoT = "DNS-over-TLS"
protoDoQ = "DNS-over-QUIC"
)
needBindingCheck := []struct {
network string
proto string
currPort uint16
newPort uint16
}{{
network: networkTCP,
proto: protoHTTPS,
currPort: currConf.PortHTTPS,
newPort: newConf.PortHTTPS,
}, {
network: networkTCP,
proto: protoDoT,
currPort: currConf.PortDNSOverTLS,
newPort: newConf.PortDNSOverTLS,
}, {
network: networkUDP,
proto: protoDoQ,
currPort: currConf.PortDNSOverQUIC,
newPort: newConf.PortDNSOverQUIC,
}}
var errs []error
for _, v := range needBindingCheck {
port := v.newPort
if v.currPort == port {
continue
}
addrPort := netip.AddrPortFrom(addr, port)
err = aghnet.CheckPort(v.network, addrPort)
if err != nil {
errs = append(errs, fmt.Errorf("port %d for %s is not available", port, v.proto))
}
}
return errors.Join(errs...)
}
// errNoIPInCert is the error that is returned from [tlsManager.parseCertChain] // errNoIPInCert is the error that is returned from [tlsManager.parseCertChain]
// if the leaf certificate doesn't contain IPs. // if the leaf certificate doesn't contain IPs.
const errNoIPInCert errors.Error = `certificates has no IP addresses; ` + const errNoIPInCert errors.Error = `certificates has no IP addresses; ` +
@ -718,27 +859,12 @@ func (m *tlsManager) validateCertificates(
) (err error) { ) (err error) {
// Check only the public certificate separately from the key. // Check only the public certificate separately from the key.
if len(certChain) > 0 { if len(certChain) > 0 {
var certs []*x509.Certificate var ok bool
certs, status.ValidCert, err = m.parseCertChain(ctx, certChain) ok, err = m.validateCertificate(ctx, status, certChain, serverName)
if !status.ValidCert { if !ok {
// Don't wrap the error, since it's informative enough as is. // Don't wrap the error, since it's informative enough as is.
return err return err
} }
mainCert := certs[0]
status.Subject = mainCert.Subject.String()
status.Issuer = mainCert.Issuer.String()
status.NotAfter = mainCert.NotAfter
status.NotBefore = mainCert.NotBefore
status.DNSNames = mainCert.DNSNames
if chainErr := m.validateCertChain(ctx, certs, serverName); chainErr != nil {
// Let self-signed certs through and don't return this error to set
// its message into the status.WarningValidation afterwards.
err = chainErr
} else {
status.ValidChain = true
}
} }
// Validate the private key by parsing it. // Validate the private key by parsing it.
@ -766,6 +892,41 @@ func (m *tlsManager) validateCertificates(
return err return err
} }
// validateCertificate processes certificate data. status must not be nil, as
// it is used to accumulate the validation results. Other parameters are
// optional. If ok is true, the returned error, if any, is not critical.
func (m *tlsManager) validateCertificate(
ctx context.Context,
status *tlsConfigStatus,
certChain []byte,
serverName string,
) (ok bool, err error) {
var certs []*x509.Certificate
certs, status.ValidCert, err = m.parseCertChain(ctx, certChain)
if !status.ValidCert {
// Don't wrap the error, since it's informative enough as is.
return false, err
}
mainCert := certs[0]
status.Subject = mainCert.Subject.String()
status.Issuer = mainCert.Issuer.String()
status.NotAfter = mainCert.NotAfter
status.NotBefore = mainCert.NotBefore
status.DNSNames = mainCert.DNSNames
err = m.validateCertChain(ctx, certs, serverName)
if err != nil {
// Let self-signed certs through and don't return this error to set
// its message into the status.WarningValidation afterwards.
return true, err
}
status.ValidChain = true
return true, nil
}
// Key types. // Key types.
const ( const (
keyTypeECDSA = "ECDSA" keyTypeECDSA = "ECDSA"
@ -828,17 +989,18 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
} }
} }
if data.PrivateKey != "" { if data.PrivateKey == "" {
var key []byte return data, nil
key, err = base64.StdEncoding.DecodeString(data.PrivateKey) }
if err != nil {
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
}
data.PrivateKey = string(key) key, err := base64.StdEncoding.DecodeString(data.PrivateKey)
if data.PrivateKeyPath != "" { if err != nil {
return data, fmt.Errorf("private key data and file can't be set together") return data, fmt.Errorf("failed to base64-decode private key: %w", err)
} }
data.PrivateKey = string(key)
if data.PrivateKeyPath != "" {
return data, fmt.Errorf("private key data and file can't be set together")
} }
return data, nil return data, nil

View file

@ -30,6 +30,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TODO(s.chzhen): Consider moving to testdata.
var testCertChainData = []byte(`-----BEGIN CERTIFICATE----- var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
@ -66,7 +67,11 @@ func TestValidateCertificates(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout) ctx := testutil.ContextWithTimeout(t, testTimeout)
logger := slogutil.NewDiscardLogger() logger := slogutil.NewDiscardLogger()
m, err := newTLSManager(ctx, logger, tlsConfigSettings{}, false) m, err := newTLSManager(ctx, &tlsManagerConfig{
logger: logger,
configModified: func() {},
servePlainDNS: false,
})
require.NoError(t, err) require.NoError(t, err)
t.Run("bad_certificate", func(t *testing.T) { t.Run("bad_certificate", func(t *testing.T) {
@ -112,7 +117,6 @@ func TestValidateCertificates(t *testing.T) {
// - [homeContext.clients.storage] // - [homeContext.clients.storage]
// - [homeContext.dnsServer] // - [homeContext.dnsServer]
// - [homeContext.mux] // - [homeContext.mux]
// - [homeContext.web]
// //
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global // TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
// variables. Make tests that use this helper concurrent. // variables. Make tests that use this helper concurrent.
@ -123,14 +127,12 @@ func storeGlobals(tb testing.TB) {
storage := globalContext.clients.storage storage := globalContext.clients.storage
dnsServer := globalContext.dnsServer dnsServer := globalContext.dnsServer
mux := globalContext.mux mux := globalContext.mux
web := globalContext.web
tb.Cleanup(func() { tb.Cleanup(func() {
config = prevConfig config = prevConfig
globalContext.clients.storage = storage globalContext.clients.storage = storage
globalContext.dnsServer = dnsServer globalContext.dnsServer = dnsServer
globalContext.mux = mux globalContext.mux = mux
globalContext.web = web
}) })
} }
@ -221,9 +223,6 @@ func TestTLSManager_Reload(t *testing.T) {
globalContext.mux = http.NewServeMux() globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
const ( const (
snBefore int64 = 1 snBefore int64 = 1
snAfter int64 = 2 snAfter int64 = 2
@ -236,15 +235,25 @@ func TestTLSManager_Reload(t *testing.T) {
certDER, key := newCertAndKey(t, snBefore) certDER, key := newCertAndKey(t, snBefore)
writeCertAndKey(t, certDER, certPath, key, keyPath) writeCertAndKey(t, certDER, certPath, key, keyPath)
m, err := newTLSManager(ctx, logger, tlsConfigSettings{ m, err := newTLSManager(ctx, &tlsManagerConfig{
Enabled: true, logger: logger,
TLSConfig: dnsforward.TLSConfig{ configModified: func() {},
CertificatePath: certPath, tlsSettings: tlsConfigSettings{
PrivateKeyPath: keyPath, Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificatePath: certPath,
PrivateKeyPath: keyPath,
},
}, },
}, false) servePlainDNS: false,
})
require.NoError(t, err) require.NoError(t, err)
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
m.setWebAPI(web)
conf := &tlsConfigSettings{} conf := &tlsConfigSettings{}
m.WriteDiskConfig(conf) m.WriteDiskConfig(conf)
assertCertSerialNumber(t, conf, snBefore) assertCertSerialNumber(t, conf, snBefore)
@ -265,13 +274,18 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
err error err error
) )
m, err := newTLSManager(ctx, logger, tlsConfigSettings{ m, err := newTLSManager(ctx, &tlsManagerConfig{
Enabled: true, logger: logger,
TLSConfig: dnsforward.TLSConfig{ configModified: func() {},
CertificateChain: string(testCertChainData), tlsSettings: tlsConfigSettings{
PrivateKey: string(testPrivateKeyData), Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificateChain: string(testCertChainData),
PrivateKey: string(testPrivateKeyData),
},
}, },
}, false) servePlainDNS: false,
})
require.NoError(t, err) require.NoError(t, err)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -291,26 +305,42 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
func TestValidateTLSSettings(t *testing.T) { func TestValidateTLSSettings(t *testing.T) {
storeGlobals(t) storeGlobals(t)
globalContext.mux = http.NewServeMux()
var ( var (
logger = slogutil.NewDiscardLogger() logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout) ctx = testutil.ContextWithTimeout(t, testTimeout)
err error err error
) )
ln, err := net.Listen("tcp", ":0") m, err := newTLSManager(ctx, &tlsManagerConfig{
logger: logger,
configModified: func() {},
servePlainDNS: false,
})
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, ln.Close) web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr())
busyPort := addr.Port
globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err) require.NoError(t, err)
m.setWebAPI(web)
tcpLn, err := net.Listen("tcp", ":0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, tcpLn.Close)
tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpLn.Addr())
busyTCPPort := tcpAddr.Port
udpLn, err := net.ListenPacket("udp", ":0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, udpLn.Close)
udpAddr := testutil.RequireTypeAssert[*net.UDPAddr](t, udpLn.LocalAddr())
busyUDPPort := udpAddr.Port
testCases := []struct { testCases := []struct {
setts tlsConfigSettingsExt setts tlsConfigSettingsExt
name string name string
@ -329,11 +359,29 @@ func TestValidateTLSSettings(t *testing.T) {
setts: tlsConfigSettingsExt{ setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{ tlsConfigSettings: tlsConfigSettings{
Enabled: true, Enabled: true,
PortHTTPS: uint16(busyPort), PortHTTPS: uint16(busyTCPPort),
}, },
}, },
name: "busy_port", name: "busy_https_port",
wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort), wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort),
}, {
setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortDNSOverTLS: uint16(busyTCPPort),
},
},
name: "busy_dot_port",
wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort),
}, {
setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortDNSOverQUIC: uint16(busyUDPPort),
},
},
name: "busy_doq_port",
wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort),
}, { }, {
setts: tlsConfigSettingsExt{ setts: tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{ tlsConfigSettings: tlsConfigSettings{
@ -348,7 +396,7 @@ func TestValidateTLSSettings(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = validateTLSSettings(tc.setts) err = m.validateTLSSettings(tc.setts)
testutil.AssertErrorMsg(t, tc.wantErr, err) testutil.AssertErrorMsg(t, tc.wantErr, err)
}) })
} }
@ -357,26 +405,33 @@ func TestValidateTLSSettings(t *testing.T) {
func TestTLSManager_HandleTLSValidate(t *testing.T) { func TestTLSManager_HandleTLSValidate(t *testing.T) {
storeGlobals(t) storeGlobals(t)
globalContext.mux = http.NewServeMux()
var ( var (
logger = slogutil.NewDiscardLogger() logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout) ctx = testutil.ContextWithTimeout(t, testTimeout)
err error err error
) )
globalContext.mux = http.NewServeMux() m, err := newTLSManager(ctx, &tlsManagerConfig{
logger: logger,
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) configModified: func() {},
require.NoError(t, err) tlsSettings: tlsConfigSettings{
Enabled: true,
m, err := newTLSManager(ctx, logger, tlsConfigSettings{ TLSConfig: dnsforward.TLSConfig{
Enabled: true, CertificateChain: string(testCertChainData),
TLSConfig: dnsforward.TLSConfig{ PrivateKey: string(testPrivateKeyData),
CertificateChain: string(testCertChainData), },
PrivateKey: string(testPrivateKeyData),
}, },
}, false) servePlainDNS: false,
})
require.NoError(t, err) require.NoError(t, err)
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
m.setWebAPI(web)
setts := &tlsConfigSettingsExt{ setts := &tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{ tlsConfigSettings: tlsConfigSettings{
Enabled: true, Enabled: true,
@ -438,9 +493,6 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
globalContext.mux = http.NewServeMux() globalContext.mux = http.NewServeMux()
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
config.DNS.Port = 0 config.DNS.Port = 0
@ -455,15 +507,25 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
writeCertAndKey(t, certDER, certPath, key, keyPath) writeCertAndKey(t, certDER, certPath, key, keyPath)
// Initialize the TLS manager and assert its configuration. // Initialize the TLS manager and assert its configuration.
m, err := newTLSManager(ctx, logger, tlsConfigSettings{ m, err := newTLSManager(ctx, &tlsManagerConfig{
Enabled: true, logger: logger,
TLSConfig: dnsforward.TLSConfig{ configModified: func() {},
CertificatePath: certPath, tlsSettings: tlsConfigSettings{
PrivateKeyPath: keyPath, Enabled: true,
TLSConfig: dnsforward.TLSConfig{
CertificatePath: certPath,
PrivateKeyPath: keyPath,
},
}, },
}, true) servePlainDNS: true,
})
require.NoError(t, err) require.NoError(t, err)
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
m.setWebAPI(web)
conf := &tlsConfigSettings{} conf := &tlsConfigSettings{}
m.WriteDiskConfig(conf) m.WriteDiskConfig(conf)
assertCertSerialNumber(t, conf, wantSerialNumber) assertCertSerialNumber(t, conf, wantSerialNumber)
@ -509,10 +571,10 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
// //
// TODO(s.chzhen): Remove when [httpsServer.cond] is removed. // TODO(s.chzhen): Remove when [httpsServer.cond] is removed.
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
globalContext.web.httpsServer.condLock.Lock() web.httpsServer.condLock.Lock()
defer globalContext.web.httpsServer.condLock.Unlock() defer web.httpsServer.condLock.Unlock()
cert = globalContext.web.httpsServer.cert cert = web.httpsServer.cert
if cert.Leaf == nil { if cert.Leaf == nil {
return false return false
} }

View file

@ -12,10 +12,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/httputil"
@ -158,27 +156,6 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
return w return w
} }
// webCheckPortAvailable checks if port, which is considered an HTTPS port, is
// available, unless the HTTPS server isn't active.
//
// TODO(a.garipov): Adapt for HTTP/3.
func webCheckPortAvailable(port uint16) (ok bool) {
if globalContext.web.httpsServer.server != nil {
return true
}
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
err := aghnet.CheckPort("tcp", addrPort)
if err != nil {
log.Info("web: warning: checking https port: %s", err)
return false
}
return true
}
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server // tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary. // if necessary.
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
@ -329,8 +306,8 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
Handler: hdlr, Handler: hdlr,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert}, Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: globalContext.tlsRoots, RootCAs: web.tlsManager.rootCerts,
CipherSuites: globalContext.tlsCipherIDs, CipherSuites: web.tlsManager.customCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
@ -363,8 +340,8 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
Addr: address, Addr: address,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert}, Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: globalContext.tlsRoots, RootCAs: web.tlsManager.rootCerts,
CipherSuites: globalContext.tlsCipherIDs, CipherSuites: web.tlsManager.customCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
Handler: withMiddlewares(globalContext.mux, limitRequestBody), Handler: withMiddlewares(globalContext.mux, limitRequestBody),