From 8e4bc29103e5a9093ac277f05bb069e2e7b17db1 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Wed, 19 Feb 2020 15:24:55 +0300
Subject: [PATCH] * move HTTP server code

---
 home/config.go          |   9 --
 home/control.go         |   2 +-
 home/control_install.go |   2 +-
 home/control_tls.go     |  14 +--
 home/home.go            | 177 ++++++++-----------------------------
 home/web.go             | 189 ++++++++++++++++++++++++++++++++++++++++
 6 files changed, 234 insertions(+), 159 deletions(-)
 create mode 100644 home/web.go

diff --git a/home/config.go b/home/config.go
index 41df6829..1150c1f2 100644
--- a/home/config.go
+++ b/home/config.go
@@ -2,7 +2,6 @@ package home
 
 import (
 	"io/ioutil"
-	"net/http"
 	"os"
 	"path/filepath"
 	"sync"
@@ -29,14 +28,6 @@ type logSettings struct {
 	Verbose bool   `yaml:"verbose"`  // If true, verbose logging is enabled
 }
 
-// HTTPSServer - HTTPS Server
-type HTTPSServer struct {
-	server     *http.Server
-	cond       *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
-	sync.Mutex            // protects config.TLS
-	shutdown   bool       // if TRUE, don't restart the server
-}
-
 // configuration is loaded from YAML
 // field ordering is important -- yaml fields will mirror ordering from here
 type configuration struct {
diff --git a/home/control.go b/home/control.go
index 571b0708..289cfd3c 100644
--- a/home/control.go
+++ b/home/control.go
@@ -265,7 +265,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
 		}
 
 		// enforce https?
-		if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
+		if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.web.httpsServer.server != nil {
 			// yes, and we want host from host:port
 			host, _, err := net.SplitHostPort(r.Host)
 			if err != nil {
diff --git a/home/control_install.go b/home/control_install.go
index 5b4c4297..ab6f0ca5 100644
--- a/home/control_install.go
+++ b/home/control_install.go
@@ -362,7 +362,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
 	// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
 	if restartHTTP {
 		go func() {
-			_ = Context.httpServer.Shutdown(context.TODO())
+			_ = Context.web.httpServer.Shutdown(context.TODO())
 		}()
 	}
 
diff --git a/home/control_tls.go b/home/control_tls.go
index e102ecbd..b048800d 100644
--- a/home/control_tls.go
+++ b/home/control_tls.go
@@ -82,7 +82,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
 	// check if port is available
 	// BUT: if we are already using this port, no need
 	alreadyRunning := false
-	if Context.httpsServer.server != nil {
+	if Context.web.httpsServer.server != nil {
 		alreadyRunning = true
 	}
 	if !alreadyRunning {
@@ -112,7 +112,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
 	// check if port is available
 	// BUT: if we are already using this port, no need
 	alreadyRunning := false
-	if Context.httpsServer.server != nil {
+	if Context.web.httpsServer.server != nil {
 		alreadyRunning = true
 	}
 	if !alreadyRunning {
@@ -147,12 +147,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
 	if restartHTTPS {
 		go func() {
 			time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
-			Context.httpsServer.cond.L.Lock()
-			Context.httpsServer.cond.Broadcast()
-			if Context.httpsServer.server != nil {
-				Context.httpsServer.server.Shutdown(context.TODO())
+			Context.web.httpsServer.cond.L.Lock()
+			Context.web.httpsServer.cond.Broadcast()
+			if Context.web.httpsServer.server != nil {
+				Context.web.httpsServer.server.Shutdown(context.TODO())
 			}
-			Context.httpsServer.cond.L.Unlock()
+			Context.web.httpsServer.cond.L.Unlock()
 		}()
 	}
 }
diff --git a/home/home.go b/home/home.go
index d4a4f215..dfa0118c 100644
--- a/home/home.go
+++ b/home/home.go
@@ -34,8 +34,6 @@ import (
 	"github.com/AdguardTeam/AdGuardHome/querylog"
 	"github.com/AdguardTeam/AdGuardHome/stats"
 	"github.com/AdguardTeam/golibs/log"
-	"github.com/NYTimes/gziphandler"
-	"github.com/gobuffalo/packr"
 )
 
 const (
@@ -58,18 +56,17 @@ type homeContext struct {
 	// Modules
 	// --
 
-	clients     clientsContainer     // per-client-settings module
-	stats       stats.Stats          // statistics module
-	queryLog    querylog.QueryLog    // query log module
-	dnsServer   *dnsforward.Server   // DNS module
-	rdns        *RDNS                // rDNS module
-	whois       *Whois               // WHOIS module
-	dnsFilter   *dnsfilter.Dnsfilter // DNS filtering module
-	dhcpServer  *dhcpd.Server        // DHCP module
-	auth        *Auth                // HTTP authentication module
-	httpServer  *http.Server         // HTTP module
-	httpsServer HTTPSServer          // HTTPS module
-	filters     Filtering
+	clients    clientsContainer     // per-client-settings module
+	stats      stats.Stats          // statistics module
+	queryLog   querylog.QueryLog    // query log module
+	dnsServer  *dnsforward.Server   // DNS module
+	rdns       *RDNS                // rDNS module
+	whois      *Whois               // WHOIS module
+	dnsFilter  *dnsfilter.Dnsfilter // DNS filtering module
+	dhcpServer *dhcpd.Server        // DHCP module
+	auth       *Auth                // HTTP authentication module
+	filters    Filtering
+	web        *Web
 
 	// Runtime properties
 	// --
@@ -243,9 +240,22 @@ func run(args options) {
 		log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
 	}
 
-	err = initWeb()
-	if err != nil {
-		log.Fatalf("%s", err)
+	sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
+	Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
+	if Context.auth == nil {
+		log.Fatalf("Couldn't initialize Auth module")
+	}
+	config.Users = nil
+
+	webConf := WebConfig{
+		firstRun: Context.firstRun,
+		BindHost: config.BindHost,
+		BindPort: config.BindPort,
+		TLS:      config.TLS,
+	}
+	Context.web = CreateWeb(&webConf)
+	if Context.web == nil {
+		log.Fatalf("Can't initialize Web module")
 	}
 
 	if !Context.firstRun {
@@ -266,115 +276,12 @@ func run(args options) {
 		}
 	}
 
-	startWeb()
+	Context.web.Start()
 
 	// wait indefinitely for other go-routines to complete their job
 	select {}
 }
 
-// Initialize Web modules
-func initWeb() error {
-	sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
-	Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
-	if Context.auth == nil {
-		return fmt.Errorf("Couldn't initialize Auth module")
-	}
-	config.Users = nil
-
-	// Initialize and run the admin Web interface
-	box := packr.NewBox("../build/static")
-
-	// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
-	http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
-
-	// add handlers for /install paths, we only need them when we're not configured yet
-	if Context.firstRun {
-		log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
-		http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
-		registerInstallHandlers()
-	} else {
-		registerControlHandlers()
-	}
-
-	Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
-	return nil
-}
-
-func startWeb() {
-	// for https, we have a separate goroutine loop
-	go httpServerLoop()
-
-	// this loop is used as an ability to change listening host and/or port
-	for !Context.httpsServer.shutdown {
-		printHTTPAddresses("http")
-
-		// we need to have new instance, because after Shutdown() the Server is not usable
-		address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
-		Context.httpServer = &http.Server{
-			Addr: address,
-		}
-		err := Context.httpServer.ListenAndServe()
-		if err != http.ErrServerClosed {
-			cleanupAlways()
-			log.Fatal(err)
-		}
-		// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
-	}
-}
-
-func httpServerLoop() {
-	for !Context.httpsServer.shutdown {
-		Context.httpsServer.cond.L.Lock()
-		// this mechanism doesn't let us through until all conditions are met
-		for config.TLS.Enabled == false ||
-			config.TLS.PortHTTPS == 0 ||
-			len(config.TLS.PrivateKeyData) == 0 ||
-			len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
-			Context.httpsServer.cond.Wait()
-		}
-		address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
-		// validate current TLS config and update warnings (it could have been loaded from file)
-		data := validateCertificates(string(config.TLS.CertificateChainData), string(config.TLS.PrivateKeyData), config.TLS.ServerName)
-		if !data.ValidPair {
-			cleanupAlways()
-			log.Fatal(data.WarningValidation)
-		}
-		config.Lock()
-		config.TLS.tlsConfigStatus = data // update warnings
-		config.Unlock()
-
-		// prepare certs for HTTPS server
-		// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
-		certchain := make([]byte, len(config.TLS.CertificateChainData))
-		copy(certchain, config.TLS.CertificateChainData)
-		privatekey := make([]byte, len(config.TLS.PrivateKeyData))
-		copy(privatekey, config.TLS.PrivateKeyData)
-		cert, err := tls.X509KeyPair(certchain, privatekey)
-		if err != nil {
-			cleanupAlways()
-			log.Fatal(err)
-		}
-		Context.httpsServer.cond.L.Unlock()
-
-		// prepare HTTPS server
-		Context.httpsServer.server = &http.Server{
-			Addr: address,
-			TLSConfig: &tls.Config{
-				Certificates: []tls.Certificate{cert},
-				MinVersion:   tls.VersionTLS12,
-				RootCAs:      Context.tlsRoots,
-			},
-		}
-
-		printHTTPAddresses("https")
-		err = Context.httpsServer.server.ListenAndServeTLS("", "")
-		if err != http.ErrServerClosed {
-			cleanupAlways()
-			log.Fatal(err)
-		}
-	}
-}
-
 // Check if the current user has root (administrator) rights
 //  and if not, ask and try to run as root
 func requireAdminRights() {
@@ -484,7 +391,14 @@ func configureLogger(args options) {
 func cleanup() {
 	log.Info("Stopping AdGuard Home")
 
-	stopHTTPServer()
+	if Context.web != nil {
+		Context.web.Close()
+		Context.web = nil
+	}
+	if Context.auth != nil {
+		Context.auth.Close()
+		Context.auth = nil
+	}
 
 	err := stopDNSServer()
 	if err != nil {
@@ -496,25 +410,6 @@ func cleanup() {
 	}
 }
 
-// Stop HTTP server, possibly waiting for all active connections to be closed
-func stopHTTPServer() {
-	log.Info("Stopping HTTP server...")
-	Context.httpsServer.shutdown = true
-	if Context.httpsServer.server != nil {
-		_ = Context.httpsServer.server.Shutdown(context.TODO())
-	}
-	if Context.httpServer != nil {
-		_ = Context.httpServer.Shutdown(context.TODO())
-	}
-
-	if Context.auth != nil {
-		Context.auth.Close()
-		Context.auth = nil
-	}
-
-	log.Info("Stopped HTTP server")
-}
-
 // This function is called before application exits
 func cleanupAlways() {
 	if len(Context.pidFileName) != 0 {
diff --git a/home/web.go b/home/web.go
new file mode 100644
index 00000000..28db3948
--- /dev/null
+++ b/home/web.go
@@ -0,0 +1,189 @@
+package home
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"net/http"
+	"strconv"
+	"sync"
+
+	"github.com/AdguardTeam/AdGuardHome/util"
+	"github.com/AdguardTeam/golibs/log"
+	"github.com/NYTimes/gziphandler"
+	"github.com/gobuffalo/packr"
+)
+
+type WebConfig struct {
+	firstRun  bool
+	BindHost  string
+	BindPort  int
+	PortHTTPS int
+}
+
+// HTTPSServer - HTTPS Server
+type HTTPSServer struct {
+	server   *http.Server
+	cond     *sync.Cond
+	condLock sync.Mutex
+	shutdown bool // if TRUE, don't restart the server
+	enabled  bool
+	cert     tls.Certificate
+}
+
+// Web - module object
+type Web struct {
+	conf        *WebConfig
+	forceHTTPS  bool
+	portHTTPS   int
+	httpServer  *http.Server // HTTP module
+	httpsServer HTTPSServer  // HTTPS module
+}
+
+// CreateWeb - create module
+func CreateWeb(conf *WebConfig) *Web {
+	w := Web{}
+	w.conf = conf
+
+	// Initialize and run the admin Web interface
+	box := packr.NewBox("../build/static")
+
+	// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
+	http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
+
+	// add handlers for /install paths, we only need them when we're not configured yet
+	if conf.firstRun {
+		log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
+		http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
+		w.registerInstallHandlers()
+	} else {
+		registerControlHandlers()
+	}
+
+	w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
+	return &w
+}
+
+// WebCheckPortAvailable - check if port is available
+// BUT: if we are already using this port, no need
+func WebCheckPortAvailable(port int) bool {
+	alreadyRunning := false
+	if Context.web.httpsServer.server != nil {
+		alreadyRunning = true
+	}
+	if !alreadyRunning {
+		err := util.CheckPortAvailable(config.BindHost, port)
+		if err != nil {
+			return false
+		}
+	}
+	return true
+}
+
+// TLSConfigChanged - called when TLS configuration has changed
+func (w *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
+	log.Debug("Web: applying new TLS configuration")
+	w.conf.PortHTTPS = tlsConf.PortHTTPS
+	w.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
+	w.portHTTPS = tlsConf.PortHTTPS
+
+	enabled := tlsConf.Enabled &&
+		tlsConf.PortHTTPS != 0 &&
+		len(tlsConf.PrivateKeyData) != 0 &&
+		len(tlsConf.CertificateChainData) != 0
+	var cert tls.Certificate
+	var err error
+	if enabled {
+		cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
+		if err != nil {
+			log.Fatal(err)
+		}
+	}
+
+	w.httpsServer.cond.L.Lock()
+	if w.httpsServer.server != nil {
+		w.httpsServer.server.Shutdown(context.TODO())
+	}
+	w.httpsServer.enabled = enabled
+	w.httpsServer.cert = cert
+	w.httpsServer.cond.Broadcast()
+	w.httpsServer.cond.L.Unlock()
+}
+
+// Start - start serving HTTP requests
+func (w *Web) Start() {
+	// for https, we have a separate goroutine loop
+	go w.httpServerLoop()
+
+	// this loop is used as an ability to change listening host and/or port
+	for !w.httpsServer.shutdown {
+		printHTTPAddresses("http")
+
+		// we need to have new instance, because after Shutdown() the Server is not usable
+		address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.BindPort))
+		w.httpServer = &http.Server{
+			Addr: address,
+		}
+		err := w.httpServer.ListenAndServe()
+		if err != http.ErrServerClosed {
+			cleanupAlways()
+			log.Fatal(err)
+		}
+		// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
+	}
+}
+
+// Close - stop HTTP server, possibly waiting for all active connections to be closed
+func (w *Web) Close() {
+	log.Info("Stopping HTTP server...")
+	w.httpsServer.cond.L.Lock()
+	w.httpsServer.shutdown = true
+	w.httpsServer.cond.L.Unlock()
+	if w.httpsServer.server != nil {
+		_ = w.httpsServer.server.Shutdown(context.TODO())
+	}
+	if w.httpServer != nil {
+		_ = w.httpServer.Shutdown(context.TODO())
+	}
+
+	log.Info("Stopped HTTP server")
+}
+
+func (w *Web) httpServerLoop() {
+	for {
+		w.httpsServer.cond.L.Lock()
+		if w.httpsServer.shutdown {
+			w.httpsServer.cond.L.Unlock()
+			break
+		}
+
+		// this mechanism doesn't let us through until all conditions are met
+		for !w.httpsServer.enabled { // sleep until necessary data is supplied
+			w.httpsServer.cond.Wait()
+			if w.httpsServer.shutdown {
+				w.httpsServer.cond.L.Unlock()
+				return
+			}
+		}
+
+		w.httpsServer.cond.L.Unlock()
+
+		// prepare HTTPS server
+		address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.PortHTTPS))
+		w.httpsServer.server = &http.Server{
+			Addr: address,
+			TLSConfig: &tls.Config{
+				Certificates: []tls.Certificate{w.httpsServer.cert},
+				MinVersion:   tls.VersionTLS12,
+				RootCAs:      Context.tlsRoots,
+			},
+		}
+
+		printHTTPAddresses("https")
+		err := w.httpsServer.server.ListenAndServeTLS("", "")
+		if err != http.ErrServerClosed {
+			cleanupAlways()
+			log.Fatal(err)
+		}
+	}
+}