Pull request 2303: AGDNS-2505-upd-next

Squashed commit of the following:

commit 586b0eb180
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 12 19:58:56 2024 +0300

    next: upd more

commit d729aa150f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 12 16:53:15 2024 +0300

    next/websvc: upd more

commit 0c64e6cfc6
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Nov 11 21:08:51 2024 +0300

    next: upd more

commit 05eec75222
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Nov 8 19:20:02 2024 +0300

    next: upd code
This commit is contained in:
Ainar Garipov 2024-11-13 15:44:21 +03:00
parent ac5a96fada
commit 1d6d85cff4
34 changed files with 637 additions and 601 deletions

3
go.mod
View file

@ -4,14 +4,13 @@ go 1.23.3
require ( require (
github.com/AdguardTeam/dnsproxy v0.73.3 github.com/AdguardTeam/dnsproxy v0.73.3
github.com/AdguardTeam/golibs v0.30.2 github.com/AdguardTeam/golibs v0.30.3
github.com/AdguardTeam/urlfilter v0.20.0 github.com/AdguardTeam/urlfilter v0.20.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.3.0 github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/bluele/gcache v0.0.2 github.com/bluele/gcache v0.0.2
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
github.com/digineo/go-ipset/v2 v2.2.1 github.com/digineo/go-ipset/v2 v2.2.1
github.com/dimfeld/httptreemux/v5 v5.5.0
github.com/fsnotify/fsnotify v1.8.0 github.com/fsnotify/fsnotify v1.8.0
github.com/go-ping/ping v1.1.0 github.com/go-ping/ping v1.1.0
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0

6
go.sum
View file

@ -1,7 +1,7 @@
github.com/AdguardTeam/dnsproxy v0.73.3 h1:aacr6Wu0ed94DDD+gSB6EwF8nvyq0+DAc7oFOgtgUpA= github.com/AdguardTeam/dnsproxy v0.73.3 h1:aacr6Wu0ed94DDD+gSB6EwF8nvyq0+DAc7oFOgtgUpA=
github.com/AdguardTeam/dnsproxy v0.73.3/go.mod h1:18ssqhDgOCiVIwYmmVuXVM05wSwrzkO2yjKhVRWJX/g= github.com/AdguardTeam/dnsproxy v0.73.3/go.mod h1:18ssqhDgOCiVIwYmmVuXVM05wSwrzkO2yjKhVRWJX/g=
github.com/AdguardTeam/golibs v0.30.2 h1:urU/NAyIvQOeArBqDmKCDpaRkfTCJ26uSiSuDMKQfuY= github.com/AdguardTeam/golibs v0.30.3 h1:pRxLjMCJ1cZccjZWMMuKxzQQGEpFbmtyj4Tg7nk5rY0=
github.com/AdguardTeam/golibs v0.30.2/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE= github.com/AdguardTeam/golibs v0.30.3/go.mod h1:Ir9dlHfb8nRQsG3Qgo1zoGL+k1qMbcBtb8tcnsvzdAE=
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs= github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk= github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@ -25,8 +25,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/digineo/go-ipset/v2 v2.2.1 h1:k6skY+0fMqeUjjeWO/m5OuWPSZUAn7AucHMnQ1MX77g= github.com/digineo/go-ipset/v2 v2.2.1 h1:k6skY+0fMqeUjjeWO/m5OuWPSZUAn7AucHMnQ1MX77g=
github.com/digineo/go-ipset/v2 v2.2.1/go.mod h1:wBsNzJlZlABHUITkesrggFnZQtgW5wkqw1uo8Qxe0VU= github.com/digineo/go-ipset/v2 v2.2.1/go.mod h1:wBsNzJlZlABHUITkesrggFnZQtgW5wkqw1uo8Qxe0VU=
github.com/dimfeld/httptreemux/v5 v5.5.0 h1:p8jkiMrCuZ0CmhwYLcbNbl7DDo21fozhKHQ2PccwOFQ=
github.com/dimfeld/httptreemux/v5 v5.5.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=

View file

@ -146,16 +146,6 @@ func IsOpenWrt() (ok bool) {
return isOpenWrt() return isOpenWrt()
} }
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c)
}
// IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig)
}
// SendShutdownSignal sends the shutdown signal to the channel. // SendShutdownSignal sends the shutdown signal to the channel.
func SendShutdownSignal(c chan<- os.Signal) { func SendShutdownSignal(c chan<- os.Signal) {
sendShutdownSignal(c) sendShutdownSignal(c)

View file

@ -1,22 +1,11 @@
//go:build darwin || freebsd || linux || openbsd //go:build unix
package aghos package aghos
import ( import (
"os" "os"
"os/signal"
"golang.org/x/sys/unix"
) )
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP
}
func sendShutdownSignal(_ chan<- os.Signal) { func sendShutdownSignal(_ chan<- os.Signal) {
// On Unix we are already notified by the system. // On Unix we are already notified by the system.
} }

View file

@ -4,12 +4,11 @@ package aghos
import ( import (
"os" "os"
"os/signal"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func setRlimit(val uint64) (err error) { func setRlimit(_ uint64) (err error) {
return Unsupported("setrlimit") return Unsupported("setrlimit")
} }
@ -38,14 +37,6 @@ func isOpenWrt() (ok bool) {
return false return false
} }
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP
}
func sendShutdownSignal(c chan<- os.Signal) { func sendShutdownSignal(c chan<- os.Signal) {
c <- os.Interrupt c <- os.Interrupt
} }

View file

@ -58,7 +58,7 @@ func (w *FSWatcher) Add(name string) (err error) {
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests. // ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
type ServiceWithConfig[ConfigType any] struct { type ServiceWithConfig[ConfigType any] struct {
OnStart func() (err error) OnStart func(ctx context.Context) (err error)
OnShutdown func(ctx context.Context) (err error) OnShutdown func(ctx context.Context) (err error)
OnConfig func() (c ConfigType) OnConfig func() (c ConfigType)
} }
@ -68,8 +68,8 @@ var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
// Start implements the [agh.ServiceWithConfig] interface for // Start implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig. // *ServiceWithConfig.
func (s *ServiceWithConfig[_]) Start() (err error) { func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) {
return s.OnStart() return s.OnStart(ctx)
} }
// Shutdown implements the [agh.ServiceWithConfig] interface for // Shutdown implements the [agh.ServiceWithConfig] interface for

View file

@ -82,7 +82,7 @@ type Empty struct{}
var _ agh.ServiceWithConfig[*Config] = Empty{} var _ agh.ServiceWithConfig[*Config] = Empty{}
// Start implements the [Service] interface for Empty. // Start implements the [Service] interface for Empty.
func (Empty) Start() (err error) { return nil } func (Empty) Start(_ context.Context) (err error) { return nil }
// Shutdown implements the [Service] interface for Empty. // Shutdown implements the [Service] interface for Empty.
func (Empty) Shutdown(_ context.Context) (err error) { return nil } func (Empty) Shutdown(_ context.Context) (err error) { return nil }

View file

@ -1,36 +1,9 @@
// Package agh contains common entities and interfaces of AdGuard Home. // Package agh contains common entities and interfaces of AdGuard Home.
package agh package agh
import "context" import (
"github.com/AdguardTeam/golibs/service"
// Service is the interface for API servers. )
//
// TODO(a.garipov): Consider adding a context to Start.
//
// TODO(a.garipov): Consider adding a Wait method or making an extension
// interface for that.
type Service interface {
// Start starts the service. It does not block.
Start() (err error)
// Shutdown gracefully stops the service. ctx is used to determine
// a timeout before trying to stop the service less gracefully.
Shutdown(ctx context.Context) (err error)
}
// type check
var _ Service = EmptyService{}
// EmptyService is a [Service] that does nothing.
//
// TODO(a.garipov): Remove if unnecessary.
type EmptyService struct{}
// Start implements the [Service] interface for EmptyService.
func (EmptyService) Start() (err error) { return nil }
// Shutdown implements the [Service] interface for EmptyService.
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
// ServiceWithConfig is an extension of the [Service] interface for services // ServiceWithConfig is an extension of the [Service] interface for services
// that can return their configuration. // that can return their configuration.
@ -38,7 +11,7 @@ func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
// TODO(a.garipov): Consider removing this generic interface if we figure out // TODO(a.garipov): Consider removing this generic interface if we figure out
// how to make it testable in a better way. // how to make it testable in a better way.
type ServiceWithConfig[ConfigType any] interface { type ServiceWithConfig[ConfigType any] interface {
Service service.Interface
Config() (c ConfigType) Config() (c ConfigType)
} }
@ -51,7 +24,7 @@ var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
// //
// TODO(a.garipov): Remove if unnecessary. // TODO(a.garipov): Remove if unnecessary.
type EmptyServiceWithConfig[ConfigType any] struct { type EmptyServiceWithConfig[ConfigType any] struct {
EmptyService service.Empty
Conf ConfigType Conf ConfigType
} }

View file

@ -12,11 +12,15 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/service"
) )
// Main is the entry point of AdGuard Home. // Main is the entry point of AdGuard Home.
func Main(embeddedFrontend fs.FS) { func Main(embeddedFrontend fs.FS) {
ctx := context.Background()
start := time.Now() start := time.Now()
cmdName := os.Args[0] cmdName := os.Args[0]
@ -26,70 +30,69 @@ func Main(embeddedFrontend fs.FS) {
os.Exit(exitCode) os.Exit(exitCode)
} }
err = setLog(opts) baseLogger := newBaseLogger(opts)
check(err)
log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid()) baseLogger.InfoContext(
ctx,
"starting adguard home",
"version", version.Version(),
"pid", os.Getpid(),
)
if opts.workDir != "" { if opts.workDir != "" {
log.Info("changing working directory to %q", opts.workDir) baseLogger.InfoContext(ctx, "changing working directory", "dir", opts.workDir)
err = os.Chdir(opts.workDir) err = os.Chdir(opts.workDir)
check(err) errors.Check(err)
} }
frontend, err := frontendFromOpts(opts, embeddedFrontend) frontend, err := frontendFromOpts(ctx, baseLogger, opts, embeddedFrontend)
check(err) errors.Check(err)
startCtx, startCancel := context.WithTimeout(ctx, defaultTimeoutStart)
defer startCancel()
confMgrConf := &configmgr.Config{ confMgrConf := &configmgr.Config{
Frontend: frontend, BaseLogger: baseLogger,
WebAddr: opts.webAddr, Logger: baseLogger.With(slogutil.KeyPrefix, "configmgr"),
Start: start, Frontend: frontend,
FileName: opts.confFile, WebAddr: opts.webAddr,
Start: start,
FileName: opts.confFile,
} }
confMgr, err := newConfigMgr(confMgrConf) confMgr, err := configmgr.New(startCtx, confMgrConf)
check(err) errors.Check(err)
web := confMgr.Web() web := confMgr.Web()
err = web.Start() err = web.Start(startCtx)
check(err) errors.Check(err)
dns := confMgr.DNS() dns := confMgr.DNS()
err = dns.Start() err = dns.Start(startCtx)
check(err) errors.Check(err)
sigHdlr := newSignalHandler( sigHdlr := newSignalHandler(
baseLogger.With(slogutil.KeyPrefix, service.SignalHandlerPrefix),
confMgrConf, confMgrConf,
opts.pidFile, opts.pidFile,
web, web,
dns, dns,
) )
sigHdlr.handle() os.Exit(sigHdlr.handle(ctx))
} }
// defaultTimeout is the timeout used for some operations where another timeout // Default timeouts.
// hasn't been defined yet. //
const defaultTimeout = 5 * time.Second // TODO(a.garipov): Make configurable.
const (
// ctxWithDefaultTimeout is a helper function that returns a context with defaultTimeoutStart = 1 * time.Minute
// timeout set to defaultTimeout. defaultTimeoutShutdown = 5 * time.Second
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) { )
return context.WithTimeout(context.Background(), defaultTimeout)
}
// newConfigMgr returns a new configuration manager using defaultTimeout as the // newConfigMgr returns a new configuration manager using defaultTimeout as the
// context timeout. // context timeout.
func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) { func newConfigMgr(ctx context.Context, c *configmgr.Config) (m *configmgr.Manager, err error) {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
return configmgr.New(ctx, c) return configmgr.New(ctx, c)
} }
// check is a simple error-checking helper. It must only be used within Main.
func check(err error) {
if err != nil {
panic(err)
}
}

View file

@ -1,39 +1,39 @@
package cmd package cmd
import ( import (
"fmt" "io"
"log/slog"
"os" "os"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/log"
) )
// syslogServiceName is the name of the AdGuard Home service used for writing // newBaseLogger constructs a base logger based on the command-line options.
// logs to the system log. // opts must not be nil.
const syslogServiceName = "AdGuardHome" func newBaseLogger(opts *options) (baseLogger *slog.Logger) {
var output io.Writer
// setLog sets up the text logging.
//
// TODO(a.garipov): Add parameters from configuration file.
func setLog(opts *options) (err error) {
switch opts.confFile { switch opts.confFile {
case "stdout": case "stdout":
log.SetOutput(os.Stdout) output = os.Stdout
case "stderr": case "stderr":
log.SetOutput(os.Stderr) output = os.Stderr
case "syslog": case "syslog":
err = aghos.ConfigureSyslog(syslogServiceName) // TODO(a.garipov): Add a syslog handler to golibs.
if err != nil {
return fmt.Errorf("initializing syslog: %w", err)
}
default: default:
// TODO(a.garipov): Use the path. // TODO(a.garipov): Use the path.
} }
lvl := slog.LevelInfo
if opts.verbose { if opts.verbose {
log.SetLevel(log.DEBUG) lvl = slog.LevelDebug
log.Debug("verbose logging enabled")
} }
return nil return slogutil.New(&slogutil.Config{
Output: output,
// TODO(a.garipov): Get from config?
Format: slogutil.FormatText,
Level: lvl,
// TODO(a.garipov): Get from config.
AddTimestamp: true,
})
} }

View file

@ -1,11 +1,13 @@
package cmd package cmd
import ( import (
"context"
"encoding" "encoding"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"slices" "slices"
@ -14,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate" "github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/osutil"
) )
// options contains all command-line options for the AdGuardHome(.exe) binary. // options contains all command-line options for the AdGuardHome(.exe) binary.
@ -372,13 +374,13 @@ func processOptions(
) (exitCode int, needExit bool) { ) (exitCode int, needExit bool) {
if parseErr != nil { if parseErr != nil {
// Assume that usage has already been printed. // Assume that usage has already been printed.
return statusArgumentError, true return osutil.ExitCodeArgumentError, true
} }
if opts.help { if opts.help {
usage(cmdName, os.Stdout) usage(cmdName, os.Stdout)
return statusSuccess, true return osutil.ExitCodeSuccess, true
} }
if opts.version { if opts.version {
@ -388,7 +390,7 @@ func processOptions(
fmt.Printf("AdGuard Home %s\n", version.Version()) fmt.Printf("AdGuard Home %s\n", version.Version())
} }
return statusSuccess, true return osutil.ExitCodeSuccess, true
} }
if opts.checkConfig { if opts.checkConfig {
@ -396,21 +398,26 @@ func processOptions(
if err != nil { if err != nil {
_, _ = io.WriteString(os.Stdout, err.Error()+"\n") _, _ = io.WriteString(os.Stdout, err.Error()+"\n")
return statusError, true return osutil.ExitCodeFailure, true
} }
return statusSuccess, true return osutil.ExitCodeSuccess, true
} }
return 0, false return 0, false
} }
// frontendFromOpts returns the frontend to use based on the options. // frontendFromOpts returns the frontend to use based on the options.
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) { func frontendFromOpts(
ctx context.Context,
logger *slog.Logger,
opts *options,
embeddedFrontend fs.FS,
) (frontend fs.FS, err error) {
const frontendSubdir = "build/static" const frontendSubdir = "build/static"
if opts.localFrontend { if opts.localFrontend {
log.Info("warning: using local frontend files") logger.WarnContext(ctx, "using local frontend files")
return os.DirFS(frontendSubdir), nil return os.DirFS(frontendSubdir), nil
} }

View file

@ -1,18 +1,26 @@
package cmd package cmd
import ( import (
"context"
"fmt"
"log/slog"
"os" "os"
"strconv" "strconv"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/service"
) )
// signalHandler processes incoming signals and shuts services down. // signalHandler processes incoming signals and shuts services down.
type signalHandler struct { type signalHandler struct {
// logger is used for logging the operation of the signal handler.
logger *slog.Logger
// confMgrConf contains the configuration parameters for the configuration // confMgrConf contains the configuration parameters for the configuration
// manager. // manager.
confMgrConf *configmgr.Config confMgrConf *configmgr.Config
@ -24,145 +32,172 @@ type signalHandler struct {
pidFile string pidFile string
// services are the services that are shut down before application exiting. // services are the services that are shut down before application exiting.
services []agh.Service services []service.Interface
// shutdownTimeout is the timeout for the shutdown operation.
shutdownTimeout time.Duration
} }
// handle processes OS signals. // handle processes OS signals. It blocks until a termination or a
func (h *signalHandler) handle() { // reconfiguration signal is received, after which it either shuts down all
defer log.OnPanic("signalHandler.handle") // services or reconfigures them. ctx is used for logging and serves as the
// base for the shutdown timeout. status is [osutil.ExitCodeSuccess] on success
// and [osutil.ExitCodeFailure] on error.
//
// TODO(a.garipov): Add reconfiguration logic to golibs.
func (h *signalHandler) handle(ctx context.Context) (status osutil.ExitCode) {
defer slogutil.RecoverAndLog(ctx, h.logger)
h.writePID() h.writePID(ctx)
for sig := range h.signal { for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig) h.logger.InfoContext(ctx, "received", "signal", sig)
if aghos.IsReconfigureSignal(sig) { if osutil.IsReconfigureSignal(sig) {
h.reconfigure() err := h.reconfigure(ctx)
if err != nil {
h.logger.ErrorContext(ctx, "reconfiguration error", slogutil.KeyError, err)
return osutil.ExitCodeFailure
}
} else if osutil.IsShutdownSignal(sig) { } else if osutil.IsShutdownSignal(sig) {
status := h.shutdown() status = h.shutdown(ctx)
h.removePID()
log.Info("sighdlr: exiting with status %d", status) h.removePID(ctx)
os.Exit(status) return status
} }
} }
// Shouldn't happen, since h.signal is currently never closed.
panic("unexpected close of h.signal")
}
// writePID writes the PID to the file, if needed. Any errors are reported to
// log.
func (h *signalHandler) writePID(ctx context.Context) {
if h.pidFile == "" {
return
}
pid := os.Getpid()
data := strconv.AppendInt(nil, int64(pid), 10)
data = append(data, '\n')
err := aghos.WriteFile(h.pidFile, data, 0o644)
if err != nil {
h.logger.ErrorContext(ctx, "writing pidfile", slogutil.KeyError, err)
return
}
h.logger.DebugContext(ctx, "wrote pid", "file", h.pidFile, "pid", pid)
} }
// reconfigure rereads the configuration file and updates and restarts services. // reconfigure rereads the configuration file and updates and restarts services.
func (h *signalHandler) reconfigure() { func (h *signalHandler) reconfigure(ctx context.Context) (err error) {
log.Info("sighdlr: reconfiguring adguard home") h.logger.InfoContext(ctx, "reconfiguring started")
status := h.shutdown() status := h.shutdown(ctx)
if status != statusSuccess { if status != osutil.ExitCodeSuccess {
log.Info("sighdlr: reconfiguring: exiting with status %d", status) return errors.Error("shutdown failed")
os.Exit(status)
} }
// TODO(a.garipov): This is a very rough way to do it. Some services can be // TODO(a.garipov): This is a very rough way to do it. Some services can
// reconfigured without the full shutdown, and the error handling is // be reconfigured without the full shutdown, and the error handling is
// currently not the best. // currently not the best.
confMgr, err := newConfigMgr(h.confMgrConf) var errs []error
check(err)
ctx, cancel := context.WithTimeout(ctx, defaultTimeoutStart)
defer cancel()
confMgr, err := newConfigMgr(ctx, h.confMgrConf)
if err != nil {
errs = append(errs, fmt.Errorf("configuration manager: %w", err))
}
web := confMgr.Web() web := confMgr.Web()
err = web.Start() err = web.Start(ctx)
check(err) if err != nil {
errs = append(errs, fmt.Errorf("starting web: %w", err))
}
dns := confMgr.DNS() dns := confMgr.DNS()
err = dns.Start() err = dns.Start(ctx)
check(err) if err != nil {
errs = append(errs, fmt.Errorf("starting dns: %w", err))
}
h.services = []agh.Service{ if len(errs) > 0 {
return errors.Join(errs...)
}
h.services = []service.Interface{
dns, dns,
web, web,
} }
log.Info("sighdlr: successfully reconfigured adguard home") h.logger.InfoContext(ctx, "reconfiguring finished")
return nil
} }
// Exit status constants.
const (
statusSuccess = 0
statusError = 1
statusArgumentError = 2
)
// shutdown gracefully shuts down all services. // shutdown gracefully shuts down all services.
func (h *signalHandler) shutdown() (status int) { func (h *signalHandler) shutdown(ctx context.Context) (status int) {
ctx, cancel := ctxWithDefaultTimeout() ctx, cancel := context.WithTimeout(ctx, h.shutdownTimeout)
defer cancel() defer cancel()
status = statusSuccess status = osutil.ExitCodeSuccess
log.Info("sighdlr: shutting down services") h.logger.InfoContext(ctx, "shutting down")
for i, service := range h.services { for i, svc := range h.services {
err := service.Shutdown(ctx) err := svc.Shutdown(ctx)
if err != nil { if err != nil {
log.Error("sighdlr: shutting down service at index %d: %s", i, err) h.logger.ErrorContext(ctx, "shutting down service", "idx", i, slogutil.KeyError, err)
status = statusError status = osutil.ExitCodeFailure
} }
} }
return status return status
} }
// newSignalHandler returns a new signalHandler that shuts down svcs. // newSignalHandler returns a new signalHandler that shuts down svcs. logger
// and confMgrConf must not be nil.
func newSignalHandler( func newSignalHandler(
logger *slog.Logger,
confMgrConf *configmgr.Config, confMgrConf *configmgr.Config,
pidFile string, pidFile string,
svcs ...agh.Service, svcs ...service.Interface,
) (h *signalHandler) { ) (h *signalHandler) {
h = &signalHandler{ h = &signalHandler{
confMgrConf: confMgrConf, logger: logger,
signal: make(chan os.Signal, 1), confMgrConf: confMgrConf,
pidFile: pidFile, signal: make(chan os.Signal, 1),
services: svcs, pidFile: pidFile,
services: svcs,
shutdownTimeout: defaultTimeoutShutdown,
} }
notifier := osutil.DefaultSignalNotifier{} notifier := osutil.DefaultSignalNotifier{}
osutil.NotifyShutdownSignal(notifier, h.signal) osutil.NotifyShutdownSignal(notifier, h.signal)
aghos.NotifyReconfigureSignal(h.signal) osutil.NotifyReconfigureSignal(notifier, h.signal)
return h return h
} }
// writePID writes the PID to the file, if needed. Any errors are reported to
// log.
func (h *signalHandler) writePID() {
if h.pidFile == "" {
return
}
// Use 8, since most PIDs will fit.
data := make([]byte, 0, 8)
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
data = append(data, '\n')
err := aghos.WriteFile(h.pidFile, data, 0o644)
if err != nil {
log.Error("sighdlr: writing pidfile: %s", err)
return
}
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
}
// removePID removes the PID file, if any. // removePID removes the PID file, if any.
func (h *signalHandler) removePID() { func (h *signalHandler) removePID(ctx context.Context) {
if h.pidFile == "" { if h.pidFile == "" {
return return
} }
err := os.Remove(h.pidFile) err := os.Remove(h.pidFile)
if err != nil { if err != nil {
log.Error("sighdlr: removing pidfile: %s", err) h.logger.ErrorContext(ctx, "removing pidfile", slogutil.KeyError, err)
return return
} }
log.Debug("sighdlr: removed pid at %q", h.pidFile) h.logger.DebugContext(ctx, "removed pidfile", "file", h.pidFile)
} }

View file

@ -4,12 +4,11 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
) )
// Configuration Structures
// config is the top-level on-disk configuration structure. // config is the top-level on-disk configuration structure.
type config struct { type config struct {
DNS *dnsConfig `yaml:"dns"` DNS *dnsConfig `yaml:"dns"`
@ -19,35 +18,33 @@ type config struct {
SchemaVersion int `yaml:"schema_version"` SchemaVersion int `yaml:"schema_version"`
} }
const errNoConf errors.Error = "configuration not found" // type check
var _ validator = (*config)(nil)
// validate returns an error if the configuration structure is invalid. // validate implements the [validator] interface for *config.
func (c *config) validate() (err error) { func (c *config) validate() (err error) {
if c == nil { if c == nil {
return errNoConf return errors.ErrNoValue
} }
// TODO(a.garipov): Add more validations. // TODO(a.garipov): Add more validations.
// Keep this in the same order as the fields in the config. // Keep this in the same order as the fields in the config.
validators := []struct { validators := container.KeyValues[string, validator]{{
validate func() (err error) Key: "dns",
name string Value: c.DNS,
}{{
validate: c.DNS.validate,
name: "dns",
}, { }, {
validate: c.HTTP.validate, Key: "http",
name: "http", Value: c.HTTP,
}, { }, {
validate: c.Log.validate, Key: "log",
name: "log", Value: c.Log,
}} }}
for _, v := range validators { for _, kv := range validators {
err = v.validate() err = kv.Value.validate()
if err != nil { if err != nil {
return fmt.Errorf("%s: %w", v.name, err) return fmt.Errorf("%s: %w", kv.Key, err)
} }
} }
@ -65,16 +62,19 @@ type dnsConfig struct {
UseDNS64 bool `yaml:"use_dns64"` UseDNS64 bool `yaml:"use_dns64"`
} }
// validate returns an error if the DNS configuration structure is invalid. // type check
var _ validator = (*dnsConfig)(nil)
// validate implements the [validator] interface for *dnsConfig.
// //
// TODO(a.garipov): Add more validations. // TODO(a.garipov): Add more validations.
func (c *dnsConfig) validate() (err error) { func (c *dnsConfig) validate() (err error) {
// TODO(a.garipov): Add more validations. // TODO(a.garipov): Add more validations.
switch { switch {
case c == nil: case c == nil:
return errNoConf return errors.ErrNoValue
case c.UpstreamTimeout.Duration <= 0: case c.UpstreamTimeout.Duration <= 0:
return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout) return newErrNotPositive("upstream_timeout", c.UpstreamTimeout)
default: default:
return nil return nil
} }
@ -91,15 +91,18 @@ type httpConfig struct {
ForceHTTPS bool `yaml:"force_https"` ForceHTTPS bool `yaml:"force_https"`
} }
// validate returns an error if the HTTP configuration structure is invalid. // type check
var _ validator = (*httpConfig)(nil)
// validate implements the [validator] interface for *httpConfig.
// //
// TODO(a.garipov): Add more validations. // TODO(a.garipov): Add more validations.
func (c *httpConfig) validate() (err error) { func (c *httpConfig) validate() (err error) {
switch { switch {
case c == nil: case c == nil:
return errNoConf return errors.ErrNoValue
case c.Timeout.Duration <= 0: case c.Timeout.Duration <= 0:
return newMustBePositiveError("timeout", c.Timeout) return newErrNotPositive("timeout", c.Timeout)
default: default:
return c.Pprof.validate() return c.Pprof.validate()
} }
@ -111,10 +114,13 @@ type httpPprofConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
} }
// validate returns an error if the pprof configuration structure is invalid. // type check
var _ validator = (*httpPprofConfig)(nil)
// validate implements the [validator] interface for *httpPprofConfig.
func (c *httpPprofConfig) validate() (err error) { func (c *httpPprofConfig) validate() (err error) {
if c == nil { if c == nil {
return errNoConf return errors.ErrNoValue
} }
return nil return nil
@ -126,12 +132,15 @@ type logConfig struct {
Verbose bool `yaml:"verbose"` Verbose bool `yaml:"verbose"`
} }
// validate returns an error if the HTTP configuration structure is invalid. // type check
var _ validator = (*logConfig)(nil)
// validate implements the [validator] interface for *logConfig.
// //
// TODO(a.garipov): Add more validations. // TODO(a.garipov): Add more validations.
func (c *logConfig) validate() (err error) { func (c *logConfig) validate() (err error) {
if c == nil { if c == nil {
return errNoConf return errors.ErrNoValue
} }
return nil return nil

View file

@ -8,6 +8,7 @@ import (
"context" "context"
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"slices" "slices"
@ -19,18 +20,22 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// Configuration Manager
// Manager handles full and partial changes in the configuration, persisting // Manager handles full and partial changes in the configuration, persisting
// them to disk if necessary. // them to disk if necessary.
// //
// TODO(a.garipov): Support missing configs and default values. // TODO(a.garipov): Support missing configs and default values.
type Manager struct { type Manager struct {
// baseLogger is used to create loggers for other entities.
baseLogger *slog.Logger
// logger is used for logging the operation of the configuration manager.
logger *slog.Logger
// updMu makes sure that at most one reconfiguration is performed at a time. // updMu makes sure that at most one reconfiguration is performed at a time.
// updMu protects all fields below. // updMu protects all fields below.
updMu *sync.RWMutex updMu *sync.RWMutex
@ -57,12 +62,24 @@ func Validate(fileName string) (err error) {
return err return err
} }
// Don't wrap the error, because it's informative enough as is. err = conf.validate()
return conf.validate() if err != nil {
return fmt.Errorf("validating config: %w", err)
}
return nil
} }
// Config contains the configuration parameters for the configuration manager. // Config contains the configuration parameters for the configuration manager.
type Config struct { type Config struct {
// BaseLogger is used to create loggers for other entities. It must not be
// nil.
BaseLogger *slog.Logger
// Logger is used for logging the operation of the configuration manager.
// It must not be nil.
Logger *slog.Logger
// Frontend is the filesystem with the frontend files. // Frontend is the filesystem with the frontend files.
Frontend fs.FS Frontend fs.FS
@ -93,9 +110,11 @@ func New(ctx context.Context, c *Config) (m *Manager, err error) {
} }
m = &Manager{ m = &Manager{
updMu: &sync.RWMutex{}, baseLogger: c.BaseLogger,
current: conf, logger: c.Logger,
fileName: c.FileName, updMu: &sync.RWMutex{},
current: conf,
fileName: c.FileName,
} }
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start) err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
@ -137,6 +156,7 @@ func (m *Manager) assemble(
start time.Time, start time.Time,
) (err error) { ) (err error) {
dnsConf := &dnssvc.Config{ dnsConf := &dnssvc.Config{
Logger: m.baseLogger.With(slogutil.KeyPrefix, "dnssvc"),
Addresses: conf.DNS.Addresses, Addresses: conf.DNS.Addresses,
BootstrapServers: conf.DNS.BootstrapDNS, BootstrapServers: conf.DNS.BootstrapDNS,
UpstreamServers: conf.DNS.UpstreamDNS, UpstreamServers: conf.DNS.UpstreamDNS,
@ -151,6 +171,7 @@ func (m *Manager) assemble(
} }
webSvcConf := &websvc.Config{ webSvcConf := &websvc.Config{
Logger: m.baseLogger.With(slogutil.KeyPrefix, "websvc"),
Pprof: &websvc.PprofConfig{ Pprof: &websvc.PprofConfig{
Port: conf.HTTP.Pprof.Port, Port: conf.HTTP.Pprof.Port,
Enabled: conf.HTTP.Pprof.Enabled, Enabled: conf.HTTP.Pprof.Enabled,
@ -176,7 +197,7 @@ func (m *Manager) assemble(
} }
// write writes the current configuration to disk. // write writes the current configuration to disk.
func (m *Manager) write() (err error) { func (m *Manager) write(ctx context.Context) (err error) {
b, err := yaml.Marshal(m.current) b, err := yaml.Marshal(m.current)
if err != nil { if err != nil {
return fmt.Errorf("encoding: %w", err) return fmt.Errorf("encoding: %w", err)
@ -187,7 +208,7 @@ func (m *Manager) write() (err error) {
return fmt.Errorf("writing: %w", err) return fmt.Errorf("writing: %w", err)
} }
log.Info("configmgr: written to %q", m.fileName) m.logger.InfoContext(ctx, "config file written", "path", m.fileName)
return nil return nil
} }
@ -216,7 +237,7 @@ func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
m.updateCurrentDNS(c) m.updateCurrentDNS(c)
return m.write() return m.write(ctx)
} }
// updateDNS recreates the DNS service. m.updMu is expected to be locked. // updateDNS recreates the DNS service. m.updMu is expected to be locked.
@ -270,7 +291,7 @@ func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
m.updateCurrentWeb(c) m.updateCurrentWeb(c)
return m.write() return m.write(ctx)
} }
// updateWeb recreates the web service. m.upd is expected to be locked. // updateWeb recreates the web service. m.upd is expected to be locked.

View file

@ -3,25 +3,29 @@ package configmgr
import ( import (
"fmt" "fmt"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
) )
// validator is the interface for configuration entities that can validate
// themselves.
type validator interface {
// validate returns an error if the entity isn't valid.
validate() (err error)
}
// numberOrDuration is the constraint for integer types along with // numberOrDuration is the constraint for integer types along with
// timeutil.Duration. // timeutil.Duration.
type numberOrDuration interface { type numberOrDuration interface {
constraints.Integer | timeutil.Duration constraints.Integer | timeutil.Duration
} }
// newMustBePositiveError returns an error about the value that must be positive // newErrNotPositive returns an error about the value that must be positive but
// but isn't. prop is the name of the property to mention in the error message. // isn't. prop is the name of the property to mention in the error message.
// //
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS // TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
// as well. // as well.
func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) { func newErrNotPositive[T numberOrDuration](prop string, v T) (err error) {
if s, ok := any(v).(fmt.Stringer); ok { return fmt.Errorf("%s: %w, got %v", prop, errors.ErrNotPositive, v)
return fmt.Errorf("%s must be positive, got %s", prop, s)
}
return fmt.Errorf("%s must be positive, got %d", prop, v)
} }

View file

@ -1,6 +1,7 @@
package dnssvc package dnssvc
import ( import (
"log/slog"
"net/netip" "net/netip"
"time" "time"
) )
@ -9,6 +10,10 @@ import (
// //
// TODO(a.garipov): Add timeout for incoming requests. // TODO(a.garipov): Add timeout for incoming requests.
type Config struct { type Config struct {
// Logger is used for logging the operation of the web API service. It must
// not be nil.
Logger *slog.Logger
// Addresses are the addresses on which to serve plain DNS queries. // Addresses are the addresses on which to serve plain DNS queries.
Addresses []netip.AddrPort Addresses []netip.AddrPort

View file

@ -7,6 +7,7 @@ package dnssvc
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
@ -28,6 +29,7 @@ import (
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those // TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
// fields that are only used in [New] and [Service.Config]. // fields that are only used in [New] and [Service.Config].
type Service struct { type Service struct {
logger *slog.Logger
proxy *proxy.Proxy proxy *proxy.Proxy
bootstraps []string bootstraps []string
bootstrapResolvers []*upstream.UpstreamResolver bootstrapResolvers []*upstream.UpstreamResolver
@ -48,6 +50,7 @@ func New(c *Config) (svc *Service, err error) {
} }
svc = &Service{ svc = &Service{
logger: c.Logger,
bootstraps: c.BootstrapServers, bootstraps: c.BootstrapServers,
upstreams: c.UpstreamServers, upstreams: c.UpstreamServers,
dns64Prefixes: c.DNS64Prefixes, dns64Prefixes: c.DNS64Prefixes,
@ -68,6 +71,7 @@ func New(c *Config) (svc *Service, err error) {
svc.bootstrapResolvers = resolvers svc.bootstrapResolvers = resolvers
svc.proxy, err = proxy.New(&proxy.Config{ svc.proxy, err = proxy.New(&proxy.Config{
Logger: svc.logger,
UDPListenAddr: udpAddrs(c.Addresses), UDPListenAddr: udpAddrs(c.Addresses),
TCPListenAddr: tcpAddrs(c.Addresses), TCPListenAddr: tcpAddrs(c.Addresses),
UpstreamConfig: &proxy.UpstreamConfig{ UpstreamConfig: &proxy.UpstreamConfig{
@ -153,12 +157,12 @@ func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
} }
// type check // type check
var _ agh.Service = (*Service)(nil) var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil. // Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all DNS servers have tried to start, but there is no // After Start exits, all DNS servers have tried to start, but there is no
// guarantee that they did. Errors from the servers are written to the log. // guarantee that they did. Errors from the servers are written to the log.
func (svc *Service) Start() (err error) { func (svc *Service) Start(ctx context.Context) (err error) {
if svc == nil { if svc == nil {
return nil return nil
} }
@ -170,7 +174,7 @@ func (svc *Service) Start() (err error) {
svc.running.Store(err == nil) svc.running.Store(err == nil)
}() }()
return svc.proxy.Start(context.Background()) return svc.proxy.Start(ctx)
} }
// Shutdown implements the [agh.Service] interface for *Service. svc may be // Shutdown implements the [agh.Service] interface for *Service. svc may be
@ -215,6 +219,7 @@ func (svc *Service) Config() (c *Config) {
} }
c = &Config{ c = &Config{
Logger: svc.logger,
Addresses: addrs, Addresses: addrs,
BootstrapServers: svc.bootstraps, BootstrapServers: svc.bootstraps,
UpstreamServers: svc.upstreams, UpstreamServers: svc.upstreams,

View file

@ -6,16 +6,13 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests. // testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second const testTimeout = 1 * time.Second
@ -59,6 +56,7 @@ func TestService(t *testing.T) {
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout) _, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
c := &dnssvc.Config{ c := &dnssvc.Config{
Logger: slogutil.NewDiscardLogger(),
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)}, Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()}, BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
UpstreamServers: []string{upstreamAddr}, UpstreamServers: []string{upstreamAddr},
@ -71,7 +69,7 @@ func TestService(t *testing.T) {
svc, err := dnssvc.New(c) svc, err := dnssvc.New(c)
require.NoError(t, err) require.NoError(t, err)
err = svc.Start() err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err) require.NoError(t, err)
gotConf := svc.Config() gotConf := svc.Config()

View file

@ -3,12 +3,17 @@ package websvc
import ( import (
"crypto/tls" "crypto/tls"
"io/fs" "io/fs"
"log/slog"
"net/netip" "net/netip"
"time" "time"
) )
// Config is the AdGuard Home web service configuration structure. // Config is the AdGuard Home web service configuration structure.
type Config struct { type Config struct {
// Logger is used for logging the operation of the web API service. It must
// not be nil.
Logger *slog.Logger
// Pprof is the configuration for the pprof debug API. It must not be nil. // Pprof is the configuration for the pprof debug API. It must not be nil.
Pprof *PprofConfig Pprof *PprofConfig
@ -60,17 +65,20 @@ type PprofConfig struct {
// finished. // finished.
func (svc *Service) Config() (c *Config) { func (svc *Service) Config() (c *Config) {
c = &Config{ c = &Config{
Logger: svc.logger,
Pprof: &PprofConfig{ Pprof: &PprofConfig{
Port: svc.pprofPort, Port: svc.pprofPort,
Enabled: svc.pprof != nil, Enabled: svc.pprof != nil,
}, },
ConfigManager: svc.confMgr, ConfigManager: svc.confMgr,
Frontend: svc.frontend,
TLS: svc.tls, TLS: svc.tls,
// Leave Addresses and SecureAddresses empty and get the actual // Leave Addresses and SecureAddresses empty and get the actual
// addresses that include the :0 ones later. // addresses that include the :0 ones later.
Start: svc.start, Start: svc.start,
Timeout: svc.timeout, OverrideAddress: svc.overrideAddr,
ForceHTTPS: svc.forceHTTPS, Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
} }
c.Addresses, c.SecureAddresses = svc.addrs() c.Addresses, c.SecureAddresses = svc.addrs()

View file

@ -11,8 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
) )
// DNS Settings Handlers
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns // ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
// HTTP API. // HTTP API.
type ReqPatchSettingsDNS struct { type ReqPatchSettingsDNS struct {
@ -60,6 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
} }
newConf := &dnssvc.Config{ newConf := &dnssvc.Config{
Logger: svc.logger,
Addresses: req.Addresses, Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers, BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers, UpstreamServers: req.UpstreamServers,
@ -78,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
} }
newSvc := svc.confMgr.DNS() newSvc := svc.confMgr.DNS()
err = newSvc.Start() err = newSvc.Start(ctx)
if err != nil { if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err)) aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))

View file

@ -35,7 +35,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
confMgr := newConfigManager() confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
return &aghtest.ServiceWithConfig[*dnssvc.Config]{ return &aghtest.ServiceWithConfig[*dnssvc.Config]{
OnStart: func() (err error) { OnStart: func(_ context.Context) (err error) {
started.Store(true) started.Store(true)
return nil return nil
@ -52,7 +52,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
u := &url.URL{ u := &url.URL{
Scheme: urlutil.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
Host: addr.String(), Host: addr.String(),
Path: websvc.PathV1SettingsDNS, Path: websvc.PathPatternV1SettingsDNS,
} }
req := jobj{ req := jobj{

View file

@ -10,11 +10,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// HTTP Settings Handlers
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http // ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API. // HTTP API.
type ReqPatchSettingsHTTP struct { type ReqPatchSettingsHTTP struct {
@ -53,6 +51,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
} }
newConf := &Config{ newConf := &Config{
Logger: svc.logger,
Pprof: &PprofConfig{ Pprof: &PprofConfig{
Port: svc.pprofPort, Port: svc.pprofPort,
Enabled: svc.pprof != nil, Enabled: svc.pprof != nil,
@ -89,13 +88,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
// relaunch updates the web service in the configuration manager and starts it. // relaunch updates the web service in the configuration manager and starts it.
// It is intended to be used as a goroutine. // It is intended to be used as a goroutine.
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) { func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
defer log.OnPanic("websvc: relaunching") defer slogutil.RecoverAndLog(ctx, svc.logger)
defer cancel() defer cancel()
err := svc.confMgr.UpdateWeb(ctx, newConf) err := svc.confMgr.UpdateWeb(ctx, newConf)
if err != nil { if err != nil {
log.Error("websvc: updating web: %s", err) svc.logger.ErrorContext(ctx, "updating web", slogutil.KeyError, err)
return return
} }
@ -106,18 +105,18 @@ func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, new
var newSvc agh.ServiceWithConfig[*Config] var newSvc agh.ServiceWithConfig[*Config]
for newSvc = svc.confMgr.Web(); newSvc == svc; { for newSvc = svc.confMgr.Web(); newSvc == svc; {
if time.Since(updStart) >= maxUpdDur { if time.Since(updStart) >= maxUpdDur {
log.Error("websvc: failed to update svc after %s", maxUpdDur) svc.logger.ErrorContext(ctx, "failed to update service on time", "duration", maxUpdDur)
return return
} }
log.Debug("websvc: waiting for new websvc to be configured") svc.logger.DebugContext(ctx, "waiting for new service")
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
err = newSvc.Start() err = newSvc.Start(ctx)
if err != nil { if err != nil {
log.Error("websvc: new svc failed to start with error: %s", err) svc.logger.ErrorContext(ctx, "new service failed", slogutil.KeyError, err)
} }
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -27,14 +28,15 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
} }
svc, err := websvc.New(&websvc.Config{ svc, err := websvc.New(&websvc.Config{
Logger: slogutil.NewDiscardLogger(),
Pprof: &websvc.PprofConfig{ Pprof: &websvc.PprofConfig{
Enabled: false, Enabled: false,
}, },
TLS: &tls.Config{ TLS: &tls.Config{
Certificates: []tls.Certificate{{}}, Certificates: []tls.Certificate{{}},
}, },
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
ForceHTTPS: true, ForceHTTPS: true,
}) })
@ -48,7 +50,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
u := &url.URL{ u := &url.URL{
Scheme: urlutil.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
Host: addr.String(), Host: addr.String(),
Path: websvc.PathV1SettingsHTTP, Path: websvc.PathPatternV1SettingsHTTP,
} }
req := jobj{ req := jobj{

View file

@ -2,15 +2,11 @@ package websvc
import ( import (
"net/http" "net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
) )
// Middlewares
// jsonMw sets the content type of the response to application/json. // jsonMw sets the content type of the response to application/json.
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) { func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) { f := func(w http.ResponseWriter, r *http.Request) {
@ -21,18 +17,3 @@ func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
return http.HandlerFunc(f) return http.HandlerFunc(f)
} }
// logMw logs the queries with level debug.
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m, u := r.Method, r.RequestURI
log.Debug("websvc: %s %s started", m, u)
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
h.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}

View file

@ -1,14 +0,0 @@
package websvc
// Path constants
const (
PathRoot = "/"
PathFrontend = "/*filepath"
PathHealthCheck = "/health-check"
PathV1SettingsAll = "/api/v1/settings/all"
PathV1SettingsDNS = "/api/v1/settings/dns"
PathV1SettingsHTTP = "/api/v1/settings/http"
PathV1SystemInfo = "/api/v1/system/info"
)

View file

@ -0,0 +1,73 @@
package websvc
import (
"log/slog"
"net/http"
"github.com/AdguardTeam/golibs/netutil/httputil"
)
// Path pattern constants.
const (
PathPatternFrontend = "/"
PathPatternHealthCheck = "/health-check"
PathPatternV1SettingsAll = "/api/v1/settings/all"
PathPatternV1SettingsDNS = "/api/v1/settings/dns"
PathPatternV1SettingsHTTP = "/api/v1/settings/http"
PathPatternV1SystemInfo = "/api/v1/system/info"
)
// Route pattern constants.
const (
routePatternFrontend = http.MethodGet + " " + PathPatternFrontend
routePatternGetV1SettingsAll = http.MethodGet + " " + PathPatternV1SettingsAll
routePatternGetV1SystemInfo = http.MethodGet + " " + PathPatternV1SystemInfo
routePatternHealthCheck = http.MethodGet + " " + PathPatternHealthCheck
routePatternPatchV1SettingsDNS = http.MethodPatch + " " + PathPatternV1SettingsDNS
routePatternPatchV1SettingsHTTP = http.MethodPatch + " " + PathPatternV1SettingsHTTP
)
// route registers all necessary handlers in mux.
func (svc *Service) route(mux *http.ServeMux) {
routes := []struct {
handler http.Handler
pattern string
isJSON bool
}{{
handler: httputil.HealthCheckHandler,
pattern: routePatternHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)),
pattern: routePatternFrontend,
isJSON: false,
}, {
handler: http.HandlerFunc(svc.handleGetSettingsAll),
pattern: routePatternGetV1SettingsAll,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handlePatchSettingsDNS),
pattern: routePatternPatchV1SettingsDNS,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handlePatchSettingsHTTP),
pattern: routePatternPatchV1SettingsHTTP,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handleGetV1SystemInfo),
pattern: routePatternGetV1SystemInfo,
isJSON: true,
}}
logMw := httputil.NewLogMiddleware(svc.logger, slog.LevelDebug)
for _, r := range routes {
var hdlr http.Handler
if r.isJSON {
hdlr = jsonMw(r.handler)
} else {
hdlr = r.handler
}
mux.Handle(r.pattern, logMw.Wrap(hdlr))
}
}

View file

@ -0,0 +1,156 @@
package websvc
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// server contains an *http.Server as well as entities and data associated with
// it.
//
// TODO(a.garipov): Join with similar structs in other projects and move to
// golibs/netutil/httputil.
//
// TODO(a.garipov): Once the above standardization is complete, consider
// merging debugsvc and websvc into a single httpsvc.
type server struct {
// mu protects http, logger, tcpListener, and url.
mu *sync.Mutex
http *http.Server
logger *slog.Logger
tcpListener *net.TCPListener
url *url.URL
tlsConf *tls.Config
initialAddr netip.AddrPort
}
// loggerKeyServer is the key used by [server] to identify itself.
const loggerKeyServer = "server"
// newServer returns a *server that is ready to serve HTTP queries. The TCP
// listener is not started. handler must not be nil.
func newServer(
baseLogger *slog.Logger,
initialAddr netip.AddrPort,
tlsConf *tls.Config,
handler http.Handler,
timeout time.Duration,
) (s *server) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: initialAddr.String(),
}
if tlsConf != nil {
u.Scheme = urlutil.SchemeHTTPS
}
logger := baseLogger.With(loggerKeyServer, u)
return &server{
mu: &sync.Mutex{},
http: &http.Server{
Handler: handler,
ReadTimeout: timeout,
ReadHeaderTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
},
logger: logger,
url: u,
tlsConf: tlsConf,
initialAddr: initialAddr,
}
}
// localAddr returns the local address of the server if the server has started
// listening; otherwise, it returns nil.
func (s *server) localAddr() (addr net.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
if l := s.tcpListener; l != nil {
return l.Addr()
}
return nil
}
// serve starts s. baseLogger is used as a base logger for s. If s fails to
// serve with anything other than [http.ErrServerClosed], it causes an unhandled
// panic. It is intended to be used as a goroutine.
//
// TODO(a.garipov): Improve error handling.
func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) {
l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr))
if err != nil {
s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: listening tcp: %w", err))
}
func() {
s.mu.Lock()
defer s.mu.Unlock()
s.tcpListener = l
// Reassign the address in case the port was zero.
s.url.Host = l.Addr().String()
s.logger = baseLogger.With(loggerKeyServer, s.url)
s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)
}()
s.logger.InfoContext(ctx, "starting")
defer s.logger.InfoContext(ctx, "started")
err = s.http.Serve(l)
if err == nil || errors.Is(err, http.ErrServerClosed) {
return
}
s.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: serving: %w", err))
}
// shutdown shuts s down.
func (s *server) shutdown(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
var errs []error
err = s.http.Shutdown(ctx)
if err != nil {
errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err))
}
// Close the listener separately, as it might not have been closed if the
// context has been canceled.
//
// NOTE: The listener could remain uninitialized if [net.ListenTCP] failed
// in [s.serve].
if l := s.tcpListener; l != nil {
err = l.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err))
}
}
return errors.Join(errs...)
}

View file

@ -1,7 +1,6 @@
package websvc_test package websvc_test
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/netip" "net/netip"
@ -13,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -29,16 +29,10 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
BootstrapPreferIPv6: true, BootstrapPreferIPv6: true,
} }
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: aghhttp.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
confMgr := newConfigManager() confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
c, err := dnssvc.New(&dnssvc.Config{ c, err := dnssvc.New(&dnssvc.Config{
Logger: slogutil.NewDiscardLogger(),
Addresses: wantDNS.Addresses, Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers, UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers, BootstrapServers: wantDNS.BootstrapServers,
@ -50,34 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
return c return c
} }
svc, err := websvc.New(&websvc.Config{ svc, addr := newTestServer(t, confMgr)
Pprof: &websvc.PprofConfig{ u := &url.URL{
Enabled: false, Scheme: urlutil.SchemeHTTP,
}, Host: addr.String(),
TLS: &tls.Config{ Path: websvc.PathPatternV1SettingsAll,
Certificates: []tls.Certificate{{}}, }
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
require.NoError(t, err)
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return svc return svc
} }
_, addr := newTestServer(t, confMgr) wantWeb := &websvc.HTTPAPIHTTPSettings{
u := &url.URL{ Addresses: []netip.AddrPort{addr},
Scheme: urlutil.SchemeHTTP, SecureAddresses: nil,
Host: addr.String(), Timeout: aghhttp.JSONDuration(testTimeout),
Path: websvc.PathV1SettingsAll, ForceHTTPS: false,
} }
body := httpGet(t, u, http.StatusOK) body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{} resp := &websvc.RespGetV1SettingsAll{}
err = json.Unmarshal(body, resp) err := json.Unmarshal(body, resp)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS) assert.Equal(t, wantDNS, resp.DNS)

View file

@ -20,7 +20,7 @@ func TestService_handleGetV1SystemInfo(t *testing.T) {
u := &url.URL{ u := &url.URL{
Scheme: urlutil.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
Host: addr.String(), Host: addr.String(),
Path: websvc.PathV1SystemInfo, Path: websvc.PathPatternV1SystemInfo,
} }
body := httpGet(t, u, http.StatusOK) body := httpGet(t, u, http.StatusOK)

View file

@ -1,31 +0,0 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View file

@ -1,40 +0,0 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
var accepted atomic.Bool
var l net.Listener = &fakenet.Listener{
OnAccept: func() (conn net.Conn, err error) {
accepted.Store(true)
return nil, nil
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
}

View file

@ -10,22 +10,18 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"net" "log/slog"
"net/http" "net/http"
"net/netip" "net/netip"
"runtime" "runtime"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/httputil"
httptreemux "github.com/dimfeld/httptreemux/v5"
) )
// ConfigManager is the configuration manager interface. // ConfigManager is the configuration manager interface.
@ -40,13 +36,14 @@ type ConfigManager interface {
// Service is the AdGuard Home web service. A nil *Service is a valid // Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing. // [agh.Service] that does nothing.
type Service struct { type Service struct {
logger *slog.Logger
confMgr ConfigManager confMgr ConfigManager
frontend fs.FS frontend fs.FS
tls *tls.Config tls *tls.Config
pprof *http.Server pprof *server
start time.Time start time.Time
overrideAddr netip.AddrPort overrideAddr netip.AddrPort
servers []*http.Server servers []*server
timeout time.Duration timeout time.Duration
pprofPort uint16 pprofPort uint16
forceHTTPS bool forceHTTPS bool
@ -64,6 +61,7 @@ func New(c *Config) (svc *Service, err error) {
} }
svc = &Service{ svc = &Service{
logger: c.Logger,
confMgr: c.ConfigManager, confMgr: c.ConfigManager,
frontend: c.Frontend, frontend: c.Frontend,
tls: c.TLS, tls: c.TLS,
@ -73,17 +71,18 @@ func New(c *Config) (svc *Service, err error) {
forceHTTPS: c.ForceHTTPS, forceHTTPS: c.ForceHTTPS,
} }
mux := newMux(svc) mux := http.NewServeMux()
svc.route(mux)
if svc.overrideAddr != (netip.AddrPort{}) { if svc.overrideAddr != (netip.AddrPort{}) {
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)} svc.servers = []*server{newServer(svc.logger, svc.overrideAddr, nil, mux, c.Timeout)}
} else { } else {
for _, a := range c.Addresses { for _, a := range c.Addresses {
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout)) svc.servers = append(svc.servers, newServer(svc.logger, a, nil, mux, c.Timeout))
} }
for _, a := range c.SecureAddresses { for _, a := range c.SecureAddresses {
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout)) svc.servers = append(svc.servers, newServer(svc.logger, a, c.TLS, mux, c.Timeout))
} }
} }
@ -112,96 +111,7 @@ func (svc *Service) setupPprof(c *PprofConfig) {
svc.pprofPort = c.Port svc.pprofPort = c.Port
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port) addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
// TODO(a.garipov): Consider making pprof timeout configurable. svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute)
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
}
// newSrv returns a new *http.Server with the given parameters.
func newSrv(
addr netip.AddrPort,
tlsConf *tls.Config,
h http.Handler,
timeout time.Duration,
) (srv *http.Server) {
addrStr := addr.String()
srv = &http.Server{
Addr: addrStr,
Handler: h,
TLSConfig: tlsConf,
ReadTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ReadHeaderTimeout: timeout,
}
if tlsConf == nil {
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
} else {
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
}
return srv
}
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
// service.
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
mux = httptreemux.NewContextMux()
routes := []struct {
handler http.HandlerFunc
method string
pattern string
isJSON bool
}{{
handler: svc.handleGetHealthCheck,
method: http.MethodGet,
pattern: PathHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathFrontend,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathRoot,
isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
method: http.MethodGet,
pattern: PathV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
method: http.MethodPatch,
pattern: PathV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
method: http.MethodPatch,
pattern: PathV1SettingsHTTP,
isJSON: true,
}, {
handler: svc.handleGetV1SystemInfo,
method: http.MethodGet,
pattern: PathV1SystemInfo,
isJSON: true,
}}
for _, r := range routes {
var hdlr http.Handler
if r.isJSON {
hdlr = jsonMw(r.handler)
} else {
hdlr = r.handler
}
mux.Handle(r.method, r.pattern, logMw(hdlr))
}
return mux
} }
// addrs returns all addresses on which this server serves the HTTP API. addrs // addrs returns all addresses on which this server serves the HTTP API. addrs
@ -214,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
} }
for _, srv := range svc.servers { for _, srv := range svc.servers {
// Use MustParseAddrPort, since no errors should technically happen addrPort := netutil.NetAddrToAddrPort(srv.localAddr())
// here, because all servers must have a valid address. if addrPort == (netip.AddrPort{}) {
addrPort := netip.MustParseAddrPort(srv.Addr) continue
}
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead if srv.tlsConf == nil {
// of relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, addrPort) addrs = append(addrs, addrPort)
} else { } else {
secureAddrs = append(secureAddrs, addrPort) secureAddrs = append(secureAddrs, addrPort)
@ -231,74 +139,60 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
return addrs, secureAddrs return addrs, secureAddrs
} }
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// type check // type check
var _ agh.Service = (*Service)(nil) var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil. // Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all HTTP servers have tried to start, possibly failing and // After Start exits, all HTTP servers have tried to start, possibly failing and
// writing error messages to the log. // writing error messages to the log.
func (svc *Service) Start() (err error) { //
// TODO(a.garipov): Use the context for cancelation as well.
func (svc *Service) Start(ctx context.Context) (err error) {
if svc == nil { if svc == nil {
return nil return nil
} }
pprofEnabled := svc.pprof != nil svc.logger.InfoContext(ctx, "starting")
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled) defer svc.logger.InfoContext(ctx, "started")
wg := &sync.WaitGroup{}
wg.Add(srvNum)
for _, srv := range svc.servers { for _, srv := range svc.servers {
go serve(srv, wg) go srv.serve(ctx, svc.logger)
} }
if pprofEnabled { if svc.pprof != nil {
go serve(svc.pprof, wg) go svc.pprof.serve(ctx, svc.logger)
} }
wg.Wait() return svc.wait(ctx)
}
// wait waits until either the context is canceled or all servers have started.
func (svc *Service) wait(ctx context.Context) (err error) {
for !svc.serversHaveStarted() {
select {
case <-ctx.Done():
return ctx.Err()
default:
// Wait and let the other goroutines do their job.
runtime.Gosched()
}
}
return nil return nil
} }
// serve starts and runs srv and writes all errors into its log. // serversHaveStarted returns true if all servers have started serving.
func serve(srv *http.Server, wg *sync.WaitGroup) { func (svc *Service) serversHaveStarted() (started bool) {
addr := srv.Addr started = len(svc.servers) != 0
defer log.OnPanic(addr) for _, srv := range svc.servers {
started = started && srv.localAddr() != nil
var proto string
var l net.Listener
var err error
if srv.TLSConfig == nil {
proto = "http"
l, err = net.Listen("tcp", addr)
} else {
proto = "https"
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
}
if err != nil {
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
} }
// Update the server's address in case the address had the port zero, which if svc.pprof != nil {
// would mean that a random available port was automatically chosen. started = started && svc.pprof.localAddr() != nil
srv.Addr = l.Addr().String()
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
} }
err = srv.Serve(l) return started
if err != nil && !errors.Is(err, http.ErrServerClosed) {
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
}
} }
// Shutdown implements the [agh.Service] interface for *Service. svc may be // Shutdown implements the [agh.Service] interface for *Service. svc may be
@ -308,20 +202,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil return nil
} }
svc.logger.InfoContext(ctx, "shutting down")
defer svc.logger.InfoContext(ctx, "shut down")
defer func() { err = errors.Annotate(err, "shutting down: %w") }() defer func() { err = errors.Annotate(err, "shutting down: %w") }()
var errs []error var errs []error
for _, srv := range svc.servers { for _, srv := range svc.servers {
shutdownErr := srv.Shutdown(ctx) shutdownErr := srv.shutdown(ctx)
if shutdownErr != nil { if shutdownErr != nil {
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr)) // Don't wrap the error, because it's informative enough as is.
errs = append(errs, err)
} }
} }
if svc.pprof != nil { if svc.pprof != nil {
shutdownErr := svc.pprof.Shutdown(ctx) shutdownErr := svc.pprof.shutdown(ctx)
if shutdownErr != nil { if shutdownErr != nil {
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr)) errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
} }
} }

View file

@ -1,6 +0,0 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View file

@ -15,6 +15,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs" "github.com/AdguardTeam/golibs/testutil/fakefs"
@ -22,10 +24,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests. // testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second const testTimeout = 1 * time.Second
@ -81,8 +79,6 @@ func newConfigManager() (m *configManager) {
// newTestServer creates and starts a new web service instance as well as its // newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the // sole address. It also registers a cleanup procedure, which shuts the
// instance down. // instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer( func newTestServer(
t testing.TB, t testing.TB,
confMgr websvc.ConfigManager, confMgr websvc.ConfigManager,
@ -90,6 +86,7 @@ func newTestServer(
t.Helper() t.Helper()
c := &websvc.Config{ c := &websvc.Config{
Logger: slogutil.NewDiscardLogger(),
Pprof: &websvc.PprofConfig{ Pprof: &websvc.PprofConfig{
Enabled: false, Enabled: false,
}, },
@ -108,7 +105,7 @@ func newTestServer(
svc, err := websvc.New(c) svc, err := websvc.New(c)
require.NoError(t, err) require.NoError(t, err)
err = svc.Start() err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { testutil.CleanupAndRequireSuccess(t, func() (err error) {
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
@ -184,10 +181,10 @@ func TestService_Start_getHealthCheck(t *testing.T) {
u := &url.URL{ u := &url.URL{
Scheme: urlutil.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
Host: addr.String(), Host: addr.String(),
Path: websvc.PathHealthCheck, Path: websvc.PathPatternHealthCheck,
} }
body := httpGet(t, u, http.StatusOK) body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body) assert.Equal(t, []byte(httputil.HealthCheckHandler), body)
} }