diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index c76ccf75..6095af08 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -180,6 +180,8 @@ type ServerConfig struct { FilteringConfig TLSConfig + TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 + // Called when the configuration is changed by HTTP request ConfigModified func() @@ -338,6 +340,7 @@ func (s *Server) Prepare(config *ServerConfig) error { MinVersion: tls.VersionTLS12, } } + upstream.RootCAs = s.conf.TLSv12Roots if len(proxyConfig.Upstreams) == 0 { log.Fatal("len(proxyConfig.Upstreams) == 0") diff --git a/home/control_tls.go b/home/control_tls.go index 0df8b729..e102ecbd 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -200,6 +200,7 @@ func verifyCertChain(data *tlsConfigStatus, certChain string, serverName string) opts := x509.VerifyOptions{ DNSName: serverName, + Roots: Context.tlsRoots, } log.Printf("number of certs - %d", len(parsedCerts)) diff --git a/home/dns.go b/home/dns.go index 167662b1..90461930 100644 --- a/home/dns.go +++ b/home/dns.go @@ -175,6 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig { newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS} } } + newconfig.TLSv12Roots = Context.tlsRoots newconfig.FilterHandler = applyAdditionalFiltering newconfig.GetUpstreamsByClient = getUpstreamsByClient diff --git a/home/home.go b/home/home.go index abb46809..e30d4f7a 100644 --- a/home/home.go +++ b/home/home.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "crypto/tls" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -78,6 +79,7 @@ type homeContext struct { pidFileName string // PID file name. Empty if no PID file was created. disableUpdate bool // If set, don't check for updates controlLock sync.Mutex + tlsRoots *x509.CertPool // list of root CAs for TLSv1.2 transport *http.Transport client *http.Client appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app @@ -135,16 +137,6 @@ func run(args options) { Context.configFilename = "AdGuardHome.yaml" } - // Init some of the Context fields right away - Context.transport = &http.Transport{ - DialContext: customDialContext, - Proxy: getHTTPProxy, - } - Context.client = &http.Client{ - Timeout: time.Minute * 5, - Transport: Context.transport, - } - // configure working dir and config path initWorkingDir(args) @@ -172,6 +164,19 @@ func run(args options) { initConfig() initServices() + Context.tlsRoots = util.LoadSystemRootCAs() + Context.transport = &http.Transport{ + DialContext: customDialContext, + Proxy: getHTTPProxy, + TLSClientConfig: &tls.Config{ + RootCAs: Context.tlsRoots, + }, + } + Context.client = &http.Client{ + Timeout: time.Minute * 5, + Transport: Context.transport, + } + if !Context.firstRun { // Do the upgrade if necessary err := upgradeConfig() @@ -321,6 +326,7 @@ func httpServerLoop() { TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, + RootCAs: Context.tlsRoots, }, } diff --git a/util/tls.go b/util/tls.go new file mode 100644 index 00000000..a4d125f9 --- /dev/null +++ b/util/tls.go @@ -0,0 +1,47 @@ +package util + +import ( + "crypto/x509" + "io/ioutil" + "os" + "runtime" + + "github.com/AdguardTeam/golibs/log" +) + +// LoadSystemRootCAs - load root CAs from the system +// Return the x509 certificate pool object +// Return nil if nothing has been found. +// This means that Go.crypto will use its default algorithm to find system root CA list. +// https://github.com/AdguardTeam/AdGuardHome/issues/1311 +func LoadSystemRootCAs() *x509.CertPool { + if runtime.GOOS != "linux" { + return nil + } + + // Directories with the system root certificates, that aren't supported by Go.crypto + dirs := []string{ + "/opt/etc/ssl/certs", // Entware + } + roots := x509.NewCertPool() + for _, dir := range dirs { + fis, err := ioutil.ReadDir(dir) + if err != nil { + if !os.IsNotExist(err) { + log.Error("Opening directory: %s: %s", dir, err) + } + continue + } + rootsAdded := false + for _, fi := range fis { + data, err := ioutil.ReadFile(dir + "/" + fi.Name()) + if err == nil && roots.AppendCertsFromPEM(data) { + rootsAdded = true + } + } + if rootsAdded { + return roots + } + } + return nil +}