mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-05-05 15:32:54 +03:00
home: refactor tls
This commit is contained in:
parent
93882d6860
commit
a8850059db
9 changed files with 454 additions and 386 deletions
internal/home
|
@ -8,42 +8,39 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||
// encryption.
|
||||
type tlsManager struct {
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
// certLastMod is the last modification time of the certificate file.
|
||||
certLastMod time.Time
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
// confMu protects conf.
|
||||
confMu *sync.RWMutex
|
||||
|
||||
// conf is the current TLS configuration.
|
||||
conf *tlsConfiguration
|
||||
}
|
||||
|
||||
// newTLSManager initializes the TLS configuration.
|
||||
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
||||
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
status: &tlsConfigStatus{},
|
||||
confMu: &sync.RWMutex{},
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
|
@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
// partialTLSConf returns a partial clone of the current TLS configuration. It
|
||||
// is safe for concurrent use.
|
||||
func (m *tlsManager) partialTLSConf() (conf *tlsConfiguration) {
|
||||
m.confMu.RLock()
|
||||
defer m.confMu.RUnlock()
|
||||
|
||||
return m.conf.partialClone()
|
||||
}
|
||||
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
// load assumes that m.confLock is locked for writing.
|
||||
func (m *tlsManager) load() (err error) {
|
||||
err = loadTLSConf(&m.conf, m.status)
|
||||
err = loadTLSConf(m.conf, m.status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
@ -70,10 +77,8 @@ func (m *tlsManager) load() (err error) {
|
|||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
m.confLock.Lock()
|
||||
*conf = m.conf
|
||||
m.confLock.Unlock()
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
|
||||
*conf = *m.partialTLSConf()
|
||||
}
|
||||
|
||||
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
||||
|
@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() {
|
|||
func (m *tlsManager) start() {
|
||||
m.registerWebHandlers()
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.TLSConfigChanged(context.Background(), m.partialTLSConf())
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts t.
|
||||
// reload updates the configuration and restarts m.
|
||||
func (m *tlsManager) reload() {
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
m.confMu.Lock()
|
||||
defer m.confMu.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||
fi, err := os.Stat(m.conf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("tls: %s", err)
|
||||
|
||||
|
@ -132,9 +132,7 @@ func (m *tlsManager) reload() {
|
|||
|
||||
log.Debug("tls: certificate file is modified")
|
||||
|
||||
m.confLock.Lock()
|
||||
err = m.load()
|
||||
m.confLock.Unlock()
|
||||
if err != nil {
|
||||
log.Error("tls: reloading: %s", err)
|
||||
|
||||
|
@ -145,19 +143,15 @@ func (m *tlsManager) reload() {
|
|||
|
||||
_ = reconfigureDNSServer()
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf = m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.TLSConfigChanged(context.Background(), m.conf)
|
||||
}
|
||||
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
// also set in status.WarningValidation.
|
||||
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
|
||||
func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
|
@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
||||
|
||||
if tlsConf.CertificatePath != "" {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
err = loadCert(tlsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Set status.ValidCert to true to signal the frontend that the
|
||||
|
@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||
}
|
||||
|
||||
if tlsConf.PrivateKeyPath != "" {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
err = loadPKey(tlsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
status.ValidKey = true
|
||||
|
@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||
return nil
|
||||
}
|
||||
|
||||
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
||||
type tlsConfigStatus struct {
|
||||
// Subject is the subject of the first certificate in the chain.
|
||||
Subject string `json:"subject,omitempty"`
|
||||
|
||||
// Issuer is the issuer of the first certificate in the chain.
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
|
||||
// KeyType is the type of the private key.
|
||||
KeyType string `json:"key_type,omitempty"`
|
||||
|
||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||
NotBefore time.Time `json:"not_before,omitempty"`
|
||||
|
||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||
NotAfter time.Time `json:"not_after,omitempty"`
|
||||
|
||||
// WarningValidation is a validation warning message with the issue
|
||||
// description.
|
||||
WarningValidation string `json:"warning_validation,omitempty"`
|
||||
|
||||
// DNSNames is the value of SubjectAltNames field of the first certificate
|
||||
// in the chain.
|
||||
DNSNames []string `json:"dns_names"`
|
||||
|
||||
// ValidCert is true if the specified certificate chain is a valid chain of
|
||||
// X509 certificates.
|
||||
ValidCert bool `json:"valid_cert"`
|
||||
|
||||
// ValidChain is true if the specified certificate chain is verified and
|
||||
// issued by a known CA.
|
||||
ValidChain bool `json:"valid_chain"`
|
||||
|
||||
// ValidKey is true if the key is a valid private key.
|
||||
ValidKey bool `json:"valid_key"`
|
||||
|
||||
// ValidPair is true if both certificate and private key are correct for
|
||||
// each other.
|
||||
ValidPair bool `json:"valid_pair"`
|
||||
}
|
||||
|
||||
// tlsConfig is the TLS configuration and status response.
|
||||
type tlsConfig struct {
|
||||
*tlsConfigStatus `json:",inline"`
|
||||
tlsConfigSettingsExt `json:",inline"`
|
||||
}
|
||||
|
||||
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
|
||||
// ensure that clients don't send and receive previously saved private keys.
|
||||
type tlsConfigSettingsExt struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
m.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: m.conf,
|
||||
},
|
||||
tlsConfigStatus: m.status,
|
||||
// loadCert loads the certificate from file, if necessary.
|
||||
func loadCert(tlsConf *tlsConfiguration) (err error) {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
m.confLock.Unlock()
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
}
|
||||
|
||||
if setts.PrivateKeySaved {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable HTTPS on it",
|
||||
setts.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = loadTLSConf(&setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
//
|
||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
log.Info("tls config has changed, restarting https server")
|
||||
restartHTTPS = true
|
||||
} else {
|
||||
log.Info("tls: config has not changed")
|
||||
// loadPKey loads the private key from file, if necessary.
|
||||
func loadPKey(tlsConf *tlsConfiguration) (err error) {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file cannot be set together")
|
||||
}
|
||||
|
||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
||||
m.conf.Enabled = newConf.Enabled
|
||||
m.conf.ServerName = newConf.ServerName
|
||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
m.conf.CertificateChain = newConf.CertificateChain
|
||||
m.conf.CertificatePath = newConf.CertificatePath
|
||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||
m.conf.PrivateKey = newConf.PrivateKey
|
||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
m.status = status
|
||||
|
||||
return restartHTTPS
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable https on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
status := &tlsConfigStatus{}
|
||||
err = loadTLSConf(&req.tlsConfigSettings, status)
|
||||
if err != nil {
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
|
||||
m.setCertFileTime()
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request. It is also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||
dnsPort, doqPort udpPort,
|
||||
) (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(betaBindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -696,61 +435,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error)
|
|||
|
||||
return nil, "", errors.Error("tls: failed to parse private key")
|
||||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
|
||||
data := tlsConfigSettingsExt{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to parse new TLS config json: %w", err)
|
||||
}
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
var cert []byte
|
||||
cert, err = base64.StdEncoding.DecodeString(data.CertificateChain)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
|
||||
}
|
||||
|
||||
data.CertificateChain = string(cert)
|
||||
if data.CertificatePath != "" {
|
||||
return data, fmt.Errorf("certificate data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
var key []byte
|
||||
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||
}
|
||||
|
||||
data.PrivateKey = string(key)
|
||||
if data.PrivateKeyPath != "" {
|
||||
return data, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
if data.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
||||
data.CertificateChain = encoded
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
data.PrivateKeySaved = true
|
||||
data.PrivateKey = ""
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
func (m *tlsManager) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue