mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-24 05:55:43 +03:00
home: refactor more
This commit is contained in:
parent
a8850059db
commit
f36efa26a4
4 changed files with 49 additions and 36 deletions
|
@ -22,7 +22,6 @@ import (
|
|||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// BlockingMode is an enum of all allowed blocking modes.
|
||||
|
@ -186,11 +185,6 @@ type TLSConfig struct {
|
|||
hasIPAddrs bool
|
||||
}
|
||||
|
||||
// CertDataClone returns a deep copy of certificate data.
|
||||
func (c TLSConfig) CertDataClone() (certData, keyData []byte) {
|
||||
return slices.Clone(c.CertificateChainData), slices.Clone(c.PrivateKeyData)
|
||||
}
|
||||
|
||||
// DNSCryptConfig is the DNSCrypt server configuration struct.
|
||||
type DNSCryptConfig struct {
|
||||
ResolverCert *dnscrypt.Cert
|
||||
|
|
|
@ -225,18 +225,25 @@ type tlsConfiguration struct {
|
|||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
}
|
||||
|
||||
// partialClone returns a clone of c with all top-level fields of c and all
|
||||
// cloneForEncoding returns a clone of c with all top-level fields of c and all
|
||||
// exported and YAML-encoded fields of c.TLSConfig cloned.
|
||||
//
|
||||
// TODO(a.garipov): This is better than races, but still not good enough.
|
||||
func (c *tlsConfiguration) partialClone() (cloned *tlsConfiguration) {
|
||||
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := *c
|
||||
cloned = &v
|
||||
cloned.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers)
|
||||
cloned.TLSConfig = dnsforward.TLSConfig{
|
||||
CertificateChain: c.CertificateChain,
|
||||
PrivateKey: c.PrivateKey,
|
||||
CertificatePath: c.CertificatePath,
|
||||
PrivateKeyPath: c.PrivateKeyPath,
|
||||
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
|
||||
StrictSNICheck: c.StrictSNICheck,
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
|
|
@ -23,15 +23,15 @@ import (
|
|||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||
// encryption.
|
||||
type tlsManager struct {
|
||||
// mu protects all fields.
|
||||
mu *sync.RWMutex
|
||||
|
||||
// certLastMod is the last modification time of the certificate file.
|
||||
certLastMod time.Time
|
||||
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
// confMu protects conf.
|
||||
confMu *sync.RWMutex
|
||||
|
||||
// conf is the current TLS configuration.
|
||||
conf *tlsConfiguration
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ type tlsManager struct {
|
|||
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
status: &tlsConfigStatus{},
|
||||
confMu: &sync.RWMutex{},
|
||||
mu: &sync.RWMutex{},
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
|
@ -56,17 +56,17 @@ func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
// partialTLSConf returns a partial clone of the current TLS configuration. It
|
||||
// confForEncoding returns a partial clone of the current TLS configuration. It
|
||||
// is safe for concurrent use.
|
||||
func (m *tlsManager) partialTLSConf() (conf *tlsConfiguration) {
|
||||
m.confMu.RLock()
|
||||
defer m.confMu.RUnlock()
|
||||
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.conf.partialClone()
|
||||
return m.conf.cloneForEncoding()
|
||||
}
|
||||
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
// load assumes that m.confLock is locked for writing.
|
||||
// m.mu is expected to be locked for writing.
|
||||
func (m *tlsManager) load() (err error) {
|
||||
err = loadTLSConf(m.conf, m.status)
|
||||
if err != nil {
|
||||
|
@ -78,11 +78,11 @@ func (m *tlsManager) load() (err error) {
|
|||
|
||||
// WriteDiskConfig - write config
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
|
||||
*conf = *m.partialTLSConf()
|
||||
*conf = *m.confForEncoding()
|
||||
}
|
||||
|
||||
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
||||
// errors, setCertFileTime logs them.
|
||||
// errors, setCertFileTime logs them. mu is expected to be locked for writing.
|
||||
func (m *tlsManager) setCertFileTime() {
|
||||
if len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
|
@ -105,13 +105,13 @@ func (m *tlsManager) start() {
|
|||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), m.partialTLSConf())
|
||||
Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts m.
|
||||
func (m *tlsManager) reload() {
|
||||
m.confMu.Lock()
|
||||
defer m.confMu.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
|
|
|
@ -77,15 +77,23 @@ type tlsConfigReq struct {
|
|||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||
}
|
||||
|
||||
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &tlsConfigResp{
|
||||
tlsConfigStatus: m.status,
|
||||
tlsConfiguration: m.partialTLSConf(),
|
||||
}
|
||||
var resp *tlsConfigResp
|
||||
func() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
resp = &tlsConfigResp{
|
||||
tlsConfigStatus: m.status,
|
||||
tlsConfiguration: m.conf.cloneForEncoding(),
|
||||
}
|
||||
}()
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
|
@ -95,7 +103,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
|
@ -127,12 +135,13 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
resp := &tlsConfigResp{
|
||||
tlsConfigStatus: &tlsConfigStatus{},
|
||||
tlsConfiguration: &req.tlsConfiguration,
|
||||
}
|
||||
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// resl.tlsConfigStatus.WarningValidation.
|
||||
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
|
@ -170,6 +179,8 @@ func validatePorts(
|
|||
return nil
|
||||
}
|
||||
|
||||
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
|
||||
// API.
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
|
@ -179,7 +190,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.partialTLSConf().PrivateKey
|
||||
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
|
@ -224,7 +235,6 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
restartRequired := m.setConf(resp)
|
||||
m.setCertFileTime()
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
|
@ -234,7 +244,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||
return
|
||||
}
|
||||
|
||||
resp.tlsConfiguration = m.partialTLSConf()
|
||||
resp.tlsConfiguration = m.confForEncoding()
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
|
@ -253,8 +263,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
// setConf sets the necessary values from the new configuration.
|
||||
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
|
||||
m.confMu.Lock()
|
||||
defer m.confMu.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
|
@ -285,6 +295,8 @@ func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
|
|||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
|
||||
m.setCertFileTime()
|
||||
|
||||
m.status = newConf.tlsConfigStatus
|
||||
|
||||
return restartRequired
|
||||
|
|
Loading…
Reference in a new issue