diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 0c3531e4..011a54f4 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -22,7 +22,6 @@ import ( "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/ameshkov/dnscrypt/v2" - "golang.org/x/exp/slices" ) // BlockingMode is an enum of all allowed blocking modes. @@ -186,11 +185,6 @@ type TLSConfig struct { hasIPAddrs bool } -// CertDataClone returns a deep copy of certificate data. -func (c TLSConfig) CertDataClone() (certData, keyData []byte) { - return slices.Clone(c.CertificateChainData), slices.Clone(c.PrivateKeyData) -} - // DNSCryptConfig is the DNSCrypt server configuration struct. type DNSCryptConfig struct { ResolverCert *dnscrypt.Cert diff --git a/internal/home/config.go b/internal/home/config.go index 7a03eef4..c5eec345 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -225,18 +225,25 @@ type tlsConfiguration struct { dnsforward.TLSConfig `yaml:",inline" json:",inline"` } -// partialClone returns a clone of c with all top-level fields of c and all +// cloneForEncoding returns a clone of c with all top-level fields of c and all // exported and YAML-encoded fields of c.TLSConfig cloned. // // TODO(a.garipov): This is better than races, but still not good enough. -func (c *tlsConfiguration) partialClone() (cloned *tlsConfiguration) { +func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) { if c == nil { return nil } v := *c cloned = &v - cloned.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers) + cloned.TLSConfig = dnsforward.TLSConfig{ + CertificateChain: c.CertificateChain, + PrivateKey: c.PrivateKey, + CertificatePath: c.CertificatePath, + PrivateKeyPath: c.PrivateKeyPath, + OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers), + StrictSNICheck: c.StrictSNICheck, + } return cloned } diff --git a/internal/home/tls.go b/internal/home/tls.go index 1c2036cc..dce777f4 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -23,15 +23,15 @@ import ( // tlsManager contains the current configuration and state of AdGuard Home TLS // encryption. type tlsManager struct { + // mu protects all fields. + mu *sync.RWMutex + // certLastMod is the last modification time of the certificate file. certLastMod time.Time // status is the current status of the configuration. It is never nil. status *tlsConfigStatus - // confMu protects conf. - confMu *sync.RWMutex - // conf is the current TLS configuration. conf *tlsConfiguration } @@ -40,7 +40,7 @@ type tlsManager struct { func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) { m = &tlsManager{ status: &tlsConfigStatus{}, - confMu: &sync.RWMutex{}, + mu: &sync.RWMutex{}, conf: conf, } @@ -56,17 +56,17 @@ func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) { return m, nil } -// partialTLSConf returns a partial clone of the current TLS configuration. It +// confForEncoding returns a partial clone of the current TLS configuration. It // is safe for concurrent use. -func (m *tlsManager) partialTLSConf() (conf *tlsConfiguration) { - m.confMu.RLock() - defer m.confMu.RUnlock() +func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) { + m.mu.RLock() + defer m.mu.RUnlock() - return m.conf.partialClone() + return m.conf.cloneForEncoding() } // load reloads the TLS configuration from files or data from the config file. -// load assumes that m.confLock is locked for writing. +// m.mu is expected to be locked for writing. func (m *tlsManager) load() (err error) { err = loadTLSConf(m.conf, m.status) if err != nil { @@ -78,11 +78,11 @@ func (m *tlsManager) load() (err error) { // WriteDiskConfig - write config func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) { - *conf = *m.partialTLSConf() + *conf = *m.confForEncoding() } // setCertFileTime sets t.certLastMod from the certificate. If there are -// errors, setCertFileTime logs them. +// errors, setCertFileTime logs them. mu is expected to be locked for writing. func (m *tlsManager) setCertFileTime() { if len(m.conf.CertificatePath) == 0 { return @@ -105,13 +105,13 @@ func (m *tlsManager) start() { // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - Context.web.TLSConfigChanged(context.Background(), m.partialTLSConf()) + Context.web.TLSConfigChanged(context.Background(), m.confForEncoding()) } // reload updates the configuration and restarts m. func (m *tlsManager) reload() { - m.confMu.Lock() - defer m.confMu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 { return diff --git a/internal/home/tlshttp.go b/internal/home/tlshttp.go index de4c84ef..31b205e5 100644 --- a/internal/home/tlshttp.go +++ b/internal/home/tlshttp.go @@ -77,15 +77,23 @@ type tlsConfigReq struct { PrivateKeySaved bool `yaml:"-" json:"private_key_saved"` } +// handleTLSStatus is the handler for the GET /control/tls/status HTTP API. func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { - resp := &tlsConfigResp{ - tlsConfigStatus: m.status, - tlsConfiguration: m.partialTLSConf(), - } + var resp *tlsConfigResp + func() { + m.mu.RLock() + defer m.mu.RUnlock() + + resp = &tlsConfigResp{ + tlsConfigStatus: m.status, + tlsConfiguration: m.conf.cloneForEncoding(), + } + }() marshalTLS(w, r, resp) } +// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API. func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { req, err := unmarshalTLS(r) if err != nil { @@ -95,7 +103,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { } if req.PrivateKeySaved { - req.PrivateKey = m.conf.PrivateKey + req.PrivateKey = m.confForEncoding().PrivateKey } if req.Enabled { @@ -127,12 +135,13 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { return } - // Skip the error check, since we are only interested in the value of - // status.WarningValidation. resp := &tlsConfigResp{ tlsConfigStatus: &tlsConfigStatus{}, tlsConfiguration: &req.tlsConfiguration, } + + // Skip the error check, since we are only interested in the value of + // resl.tlsConfigStatus.WarningValidation. _ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus) marshalTLS(w, r, resp) @@ -170,6 +179,8 @@ func validatePorts( return nil } +// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP +// API. func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { req, err := unmarshalTLS(r) if err != nil { @@ -179,7 +190,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) } if req.PrivateKeySaved { - req.PrivateKey = m.partialTLSConf().PrivateKey + req.PrivateKey = m.confForEncoding().PrivateKey } if req.Enabled { @@ -224,7 +235,6 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) } restartRequired := m.setConf(resp) - m.setCertFileTime() onConfigModified() err = reconfigureDNSServer() @@ -234,7 +244,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) return } - resp.tlsConfiguration = m.partialTLSConf() + resp.tlsConfiguration = m.confForEncoding() marshalTLS(w, r, resp) if f, ok := w.(http.Flusher); ok { f.Flush() @@ -253,8 +263,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) // setConf sets the necessary values from the new configuration. func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) { - m.confMu.Lock() - defer m.confMu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() // Reset the DNSCrypt data before comparing, since we currently do not // accept these from the frontend. @@ -285,6 +295,8 @@ func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) { m.conf.PrivateKeyPath = newConf.PrivateKeyPath m.conf.PrivateKeyData = newConf.PrivateKeyData + m.setCertFileTime() + m.status = newConf.tlsConfigStatus return restartRequired