home: refactor more

This commit is contained in:
Ainar Garipov 2022-11-21 19:45:18 +03:00
parent a8850059db
commit f36efa26a4
4 changed files with 49 additions and 36 deletions

View file

@ -22,7 +22,6 @@ import (
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/ameshkov/dnscrypt/v2"
"golang.org/x/exp/slices"
)
// BlockingMode is an enum of all allowed blocking modes.
@ -186,11 +185,6 @@ type TLSConfig struct {
hasIPAddrs bool
}
// CertDataClone returns a deep copy of certificate data.
func (c TLSConfig) CertDataClone() (certData, keyData []byte) {
return slices.Clone(c.CertificateChainData), slices.Clone(c.PrivateKeyData)
}
// DNSCryptConfig is the DNSCrypt server configuration struct.
type DNSCryptConfig struct {
ResolverCert *dnscrypt.Cert

View file

@ -225,18 +225,25 @@ type tlsConfiguration struct {
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
}
// partialClone returns a clone of c with all top-level fields of c and all
// cloneForEncoding returns a clone of c with all top-level fields of c and all
// exported and YAML-encoded fields of c.TLSConfig cloned.
//
// TODO(a.garipov): This is better than races, but still not good enough.
func (c *tlsConfiguration) partialClone() (cloned *tlsConfiguration) {
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
if c == nil {
return nil
}
v := *c
cloned = &v
cloned.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers)
cloned.TLSConfig = dnsforward.TLSConfig{
CertificateChain: c.CertificateChain,
PrivateKey: c.PrivateKey,
CertificatePath: c.CertificatePath,
PrivateKeyPath: c.PrivateKeyPath,
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
StrictSNICheck: c.StrictSNICheck,
}
return cloned
}

View file

@ -23,15 +23,15 @@ import (
// tlsManager contains the current configuration and state of AdGuard Home TLS
// encryption.
type tlsManager struct {
// mu protects all fields.
mu *sync.RWMutex
// certLastMod is the last modification time of the certificate file.
certLastMod time.Time
// status is the current status of the configuration. It is never nil.
status *tlsConfigStatus
// confMu protects conf.
confMu *sync.RWMutex
// conf is the current TLS configuration.
conf *tlsConfiguration
}
@ -40,7 +40,7 @@ type tlsManager struct {
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
m = &tlsManager{
status: &tlsConfigStatus{},
confMu: &sync.RWMutex{},
mu: &sync.RWMutex{},
conf: conf,
}
@ -56,17 +56,17 @@ func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
return m, nil
}
// partialTLSConf returns a partial clone of the current TLS configuration. It
// confForEncoding returns a partial clone of the current TLS configuration. It
// is safe for concurrent use.
func (m *tlsManager) partialTLSConf() (conf *tlsConfiguration) {
m.confMu.RLock()
defer m.confMu.RUnlock()
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.conf.partialClone()
return m.conf.cloneForEncoding()
}
// load reloads the TLS configuration from files or data from the config file.
// load assumes that m.confLock is locked for writing.
// m.mu is expected to be locked for writing.
func (m *tlsManager) load() (err error) {
err = loadTLSConf(m.conf, m.status)
if err != nil {
@ -78,11 +78,11 @@ func (m *tlsManager) load() (err error) {
// WriteDiskConfig - write config
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
*conf = *m.partialTLSConf()
*conf = *m.confForEncoding()
}
// setCertFileTime sets t.certLastMod from the certificate. If there are
// errors, setCertFileTime logs them.
// errors, setCertFileTime logs them. mu is expected to be locked for writing.
func (m *tlsManager) setCertFileTime() {
if len(m.conf.CertificatePath) == 0 {
return
@ -105,13 +105,13 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), m.partialTLSConf())
Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
}
// reload updates the configuration and restarts m.
func (m *tlsManager) reload() {
m.confMu.Lock()
defer m.confMu.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
return

View file

@ -77,15 +77,23 @@ type tlsConfigReq struct {
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
}
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
resp := &tlsConfigResp{
tlsConfigStatus: m.status,
tlsConfiguration: m.partialTLSConf(),
}
var resp *tlsConfigResp
func() {
m.mu.RLock()
defer m.mu.RUnlock()
resp = &tlsConfigResp{
tlsConfigStatus: m.status,
tlsConfiguration: m.conf.cloneForEncoding(),
}
}()
marshalTLS(w, r, resp)
}
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
@ -95,7 +103,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
}
if req.PrivateKeySaved {
req.PrivateKey = m.conf.PrivateKey
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
@ -127,12 +135,13 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
return
}
// Skip the error check, since we are only interested in the value of
// status.WarningValidation.
resp := &tlsConfigResp{
tlsConfigStatus: &tlsConfigStatus{},
tlsConfiguration: &req.tlsConfiguration,
}
// Skip the error check, since we are only interested in the value of
// resl.tlsConfigStatus.WarningValidation.
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
marshalTLS(w, r, resp)
@ -170,6 +179,8 @@ func validatePorts(
return nil
}
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
// API.
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
@ -179,7 +190,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
}
if req.PrivateKeySaved {
req.PrivateKey = m.partialTLSConf().PrivateKey
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
@ -224,7 +235,6 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
}
restartRequired := m.setConf(resp)
m.setCertFileTime()
onConfigModified()
err = reconfigureDNSServer()
@ -234,7 +244,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
return
}
resp.tlsConfiguration = m.partialTLSConf()
resp.tlsConfiguration = m.confForEncoding()
marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
@ -253,8 +263,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// setConf sets the necessary values from the new configuration.
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
m.confMu.Lock()
defer m.confMu.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
// Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend.
@ -285,6 +295,8 @@ func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.setCertFileTime()
m.status = newConf.tlsConfigStatus
return restartRequired