home: imp code

This commit is contained in:
Stanislav Chzhen 2024-12-03 17:25:37 +03:00
parent fce1bf475f
commit 5368d8de50
4 changed files with 69 additions and 35 deletions

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"net/netip" "net/netip"
"os" "os"
@ -125,6 +126,8 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
// be set. canAutofix is true if the port can be unbound by AdGuard Home // be set. canAutofix is true if the port can be unbound by AdGuard Home
// automatically. // automatically.
func (req *checkConfReq) validateDNS( func (req *checkConfReq) validateDNS(
ctx context.Context,
l *slog.Logger,
tcpPorts aghalg.UniqChecker[tcpPort], tcpPorts aghalg.UniqChecker[tcpPort],
) (canAutofix bool, err error) { ) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }() defer func() { err = errors.Annotate(err, "validating ports: %w") }()
@ -155,10 +158,10 @@ func (req *checkConfReq) validateDNS(
} }
// Try to fix automatically. // Try to fix automatically.
canAutofix = checkDNSStubListener() canAutofix = checkDNSStubListener(ctx, l)
if canAutofix && req.DNS.Autofix { if canAutofix && req.DNS.Autofix {
if derr := disableDNSStubListener(); derr != nil { if derr := disableDNSStubListener(ctx, l); derr != nil {
log.Error("disabling DNSStubListener: %s", err) l.ErrorContext(ctx, "disabling DNSStubListener", slogutil.KeyError, err)
} }
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port)) err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
@ -185,7 +188,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
resp.Web.Status = err.Error() resp.Web.Status = err.Error()
} }
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil { if resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts); err != nil {
resp.DNS.Status = err.Error() resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() { } else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP) resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
@ -234,27 +237,40 @@ func handleStaticIP(ip netip.Addr, set bool) staticIPJSON {
return resp return resp
} }
// Check if DNSStubListener is active // checkDNSStubListener returns true if DNSStubListener is active.
func checkDNSStubListener() bool { func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
return false return false
} }
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved") var cmd *exec.Cmd
log.Tracef("executing %s %v", cmd.Path, cmd.Args) var err error
_, err := cmd.Output()
defer func() {
if ok {
return
}
l.ErrorContext(
ctx,
"execution failed",
"cmd", cmd.Path,
"code", cmd.ProcessState.ExitCode(),
slogutil.KeyError, err,
)
}()
cmd = exec.Command("systemctl", "is-enabled", "systemd-resolved")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 { if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false return false
} }
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf") cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
log.Tracef("executing %s %v", cmd.Path, cmd.Args) l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output() _, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 { if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false return false
} }
@ -270,8 +286,9 @@ DNSStubListener=no
) )
const resolvConfPath = "/etc/resolv.conf" const resolvConfPath = "/etc/resolv.conf"
// Deactivate DNSStubListener // disableDNSStubListener deactivates DNSStubListerner and returns an error, if
func disableDNSStubListener() (err error) { // any.
func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
dir := filepath.Dir(resolvedConfPath) dir := filepath.Dir(resolvedConfPath)
err = os.MkdirAll(dir, 0o755) err = os.MkdirAll(dir, 0o755)
if err != nil { if err != nil {
@ -291,7 +308,7 @@ func disableDNSStubListener() (err error) {
} }
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved") cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
log.Tracef("executing %s %v", cmd.Path, cmd.Args) l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output() _, err = cmd.Output()
if err != nil { if err != nil {
return err return err

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@ -16,7 +17,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
@ -147,7 +147,7 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
// The background context is used because the underlying functions wrap it // The background context is used because the underlying functions wrap it
// with timeout and shut down the server, which handles current request. It // with timeout and shut down the server, which handles current request. It
// also should be done in a separate goroutine for the same reason. // also should be done in a separate goroutine for the same reason.
go finishUpdate(context.Background(), execPath, web.conf.runningAsService) go finishUpdate(context.Background(), web.logger, execPath, web.conf.runningAsService)
} }
// versionResponse is the response for /control/version.json endpoint. // versionResponse is the response for /control/version.json endpoint.
@ -188,10 +188,17 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
} }
// finishUpdate completes an update procedure. // finishUpdate completes an update procedure.
func finishUpdate(ctx context.Context, execPath string, runningAsService bool) { func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningAsService bool) {
var err error var err error
defer func() {
if err != nil {
l.ErrorContext(ctx, "restarting", slogutil.KeyError, err)
log.Info("stopping all tasks") os.Exit(1)
}
}()
l.InfoContext(ctx, "stopping all tasks")
cleanup(ctx) cleanup(ctx)
cleanupAlways() cleanupAlways()
@ -206,28 +213,28 @@ func finishUpdate(ctx context.Context, execPath string, runningAsService bool) {
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome") cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
log.Fatalf("restarting: stopping: %s", err) return
} }
os.Exit(0) os.Exit(0)
} }
cmd := exec.Command(execPath, os.Args[1:]...) cmd := exec.Command(execPath, os.Args[1:]...)
log.Info("restarting: %q %q", execPath, os.Args[1:]) l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
log.Fatalf("restarting:: %s", err) return
} }
os.Exit(0) os.Exit(0)
} }
log.Info("restarting: %q %q", execPath, os.Args[1:]) l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
err = syscall.Exec(execPath, os.Args, os.Environ()) err = syscall.Exec(execPath, os.Args, os.Environ())
if err != nil { if err != nil {
log.Fatalf("restarting: %s", err) return
} }
} }

View file

@ -640,7 +640,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
fatalOnError(err) fatalOnError(err)
if config.HTTPConfig.Pprof.Enabled { if config.HTTPConfig.Pprof.Enabled {
startPprof(config.HTTPConfig.Pprof.Port) startPprof(slogLogger, config.HTTPConfig.Pprof.Port)
} }
} }

View file

@ -199,6 +199,9 @@ func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettin
web.httpsServer.cond.L.Unlock() web.httpsServer.cond.L.Unlock()
} }
// loggerKeyServer is the key used by [webAPI] to identify servers.
const loggerKeyServer = "server"
// start - start serving HTTP requests // start - start serving HTTP requests
func (web *webAPI) start(ctx context.Context) { func (web *webAPI) start(ctx context.Context) {
web.logger.InfoContext(ctx, "AdGuard Home is available at the following addresses:") web.logger.InfoContext(ctx, "AdGuard Home is available at the following addresses:")
@ -216,14 +219,16 @@ func (web *webAPI) start(ctx context.Context) {
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{}) hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
logger := web.baseLogger.With(loggerKeyServer, "plain")
// Create a new instance, because the Web is not usable after Shutdown. // Create a new instance, because the Web is not usable after Shutdown.
web.httpServer = &http.Server{ web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: plain", log.DEBUG),
Addr: web.conf.BindAddr.String(), Addr: web.conf.BindAddr.String(),
Handler: hdlr, Handler: hdlr,
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout, WriteTimeout: web.conf.WriteTimeout,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
} }
go func() { go func() {
errs <- web.httpServer.ListenAndServe() errs <- web.httpServer.ListenAndServe()
@ -289,19 +294,21 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
}() }()
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String() addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
logger := web.baseLogger.With(loggerKeyServer, "https")
web.httpsServer.server = &http.Server{ web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: addr, Addr: addr,
Handler: withMiddlewares(Context.mux, limitRequestBody),
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert}, Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots, RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs, CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout, WriteTimeout: web.conf.WriteTimeout,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
} }
printHTTPAddresses(urlutil.SchemeHTTPS) printHTTPAddresses(urlutil.SchemeHTTPS)
@ -342,7 +349,7 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
} }
// startPprof launches the debug and profiling server on the provided port. // startPprof launches the debug and profiling server on the provided port.
func startPprof(port uint16) { func startPprof(baseLogger *slog.Logger, port uint16) {
addr := netip.AddrPortFrom(netutil.IPv4Localhost(), port) addr := netip.AddrPortFrom(netutil.IPv4Localhost(), port)
runtime.SetBlockProfileRate(1) runtime.SetBlockProfileRate(1)
@ -352,12 +359,15 @@ func startPprof(port uint16) {
httputil.RoutePprof(mux) httputil.RoutePprof(mux)
go func() { go func() {
defer log.OnPanic("pprof server") ctx := context.Background()
logger := baseLogger.With(slogutil.KeyPrefix, "pprof")
log.Info("pprof: listening on %q", addr) defer slogutil.RecoverAndLog(ctx, logger)
logger.InfoContext(ctx, "listening", "addr", addr)
err := http.ListenAndServe(addr.String(), mux) err := http.ListenAndServe(addr.String(), mux)
if !errors.Is(err, http.ErrServerClosed) { if !errors.Is(err, http.ErrServerClosed) {
log.Error("pprof: shutting down: %s", err) logger.ErrorContext(ctx, "shutting down", slogutil.KeyError, err)
} }
}() }()
} }