diff --git a/go.mod b/go.mod index e4d8bd5c..14ffe2e3 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,13 @@ go 1.23.3 require ( github.com/AdguardTeam/dnsproxy v0.73.3 - github.com/AdguardTeam/golibs v0.30.2 + github.com/AdguardTeam/golibs v0.30.3 github.com/AdguardTeam/urlfilter v0.20.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.3.0 github.com/bluele/gcache v0.0.2 github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 github.com/digineo/go-ipset/v2 v2.2.1 - github.com/dimfeld/httptreemux/v5 v5.5.0 github.com/fsnotify/fsnotify v1.8.0 github.com/go-ping/ping v1.1.0 github.com/google/go-cmp v0.6.0 diff --git a/go.sum b/go.sum index 00d0409c..cabad9ac 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/AdguardTeam/dnsproxy v0.73.3 h1:aacr6Wu0ed94DDD+gSB6EwF8nvyq0+DAc7oFOgtgUpA= github.com/AdguardTeam/dnsproxy v0.73.3/go.mod h1:18ssqhDgOCiVIwYmmVuXVM05wSwrzkO2yjKhVRWJX/g= -github.com/AdguardTeam/golibs v0.30.2 h1:urU/NAyIvQOeArBqDmKCDpaRkfTCJ26uSiSuDMKQfuY= -github.com/AdguardTeam/golibs v0.30.2/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE= +github.com/AdguardTeam/golibs v0.30.3 h1:pRxLjMCJ1cZccjZWMMuKxzQQGEpFbmtyj4Tg7nk5rY0= +github.com/AdguardTeam/golibs v0.30.3/go.mod h1:Ir9dlHfb8nRQsG3Qgo1zoGL+k1qMbcBtb8tcnsvzdAE= github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs= github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= @@ -25,8 +25,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/digineo/go-ipset/v2 v2.2.1 h1:k6skY+0fMqeUjjeWO/m5OuWPSZUAn7AucHMnQ1MX77g= github.com/digineo/go-ipset/v2 v2.2.1/go.mod h1:wBsNzJlZlABHUITkesrggFnZQtgW5wkqw1uo8Qxe0VU= -github.com/dimfeld/httptreemux/v5 v5.5.0 h1:p8jkiMrCuZ0CmhwYLcbNbl7DDo21fozhKHQ2PccwOFQ= -github.com/dimfeld/httptreemux/v5 v5.5.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= diff --git a/internal/aghos/os.go b/internal/aghos/os.go index f9ab2071..693cf7d2 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -146,16 +146,6 @@ func IsOpenWrt() (ok bool) { return isOpenWrt() } -// NotifyReconfigureSignal notifies c on receiving reconfigure signals. -func NotifyReconfigureSignal(c chan<- os.Signal) { - notifyReconfigureSignal(c) -} - -// IsReconfigureSignal returns true if sig is a reconfigure signal. -func IsReconfigureSignal(sig os.Signal) (ok bool) { - return isReconfigureSignal(sig) -} - // SendShutdownSignal sends the shutdown signal to the channel. func SendShutdownSignal(c chan<- os.Signal) { sendShutdownSignal(c) diff --git a/internal/aghos/os_unix.go b/internal/aghos/os_unix.go index f2cc4fef..42fbe1a7 100644 --- a/internal/aghos/os_unix.go +++ b/internal/aghos/os_unix.go @@ -1,22 +1,11 @@ -//go:build darwin || freebsd || linux || openbsd +//go:build unix package aghos import ( "os" - "os/signal" - - "golang.org/x/sys/unix" ) -func notifyReconfigureSignal(c chan<- os.Signal) { - signal.Notify(c, unix.SIGHUP) -} - -func isReconfigureSignal(sig os.Signal) (ok bool) { - return sig == unix.SIGHUP -} - func sendShutdownSignal(_ chan<- os.Signal) { // On Unix we are already notified by the system. } diff --git a/internal/aghos/os_windows.go b/internal/aghos/os_windows.go index b9bf8a4c..bff3b92c 100644 --- a/internal/aghos/os_windows.go +++ b/internal/aghos/os_windows.go @@ -4,12 +4,11 @@ package aghos import ( "os" - "os/signal" "golang.org/x/sys/windows" ) -func setRlimit(val uint64) (err error) { +func setRlimit(_ uint64) (err error) { return Unsupported("setrlimit") } @@ -38,14 +37,6 @@ func isOpenWrt() (ok bool) { return false } -func notifyReconfigureSignal(c chan<- os.Signal) { - signal.Notify(c, windows.SIGHUP) -} - -func isReconfigureSignal(sig os.Signal) (ok bool) { - return sig == windows.SIGHUP -} - func sendShutdownSignal(c chan<- os.Signal) { c <- os.Interrupt } diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index bc86721e..8db2882b 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -58,7 +58,7 @@ func (w *FSWatcher) Add(name string) (err error) { // ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests. type ServiceWithConfig[ConfigType any] struct { - OnStart func() (err error) + OnStart func(ctx context.Context) (err error) OnShutdown func(ctx context.Context) (err error) OnConfig func() (c ConfigType) } @@ -68,8 +68,8 @@ var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil) // Start implements the [agh.ServiceWithConfig] interface for // *ServiceWithConfig. -func (s *ServiceWithConfig[_]) Start() (err error) { - return s.OnStart() +func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) { + return s.OnStart(ctx) } // Shutdown implements the [agh.ServiceWithConfig] interface for diff --git a/internal/dhcpsvc/dhcpsvc.go b/internal/dhcpsvc/dhcpsvc.go index b6c77786..0445966d 100644 --- a/internal/dhcpsvc/dhcpsvc.go +++ b/internal/dhcpsvc/dhcpsvc.go @@ -82,7 +82,7 @@ type Empty struct{} var _ agh.ServiceWithConfig[*Config] = Empty{} // Start implements the [Service] interface for Empty. -func (Empty) Start() (err error) { return nil } +func (Empty) Start(_ context.Context) (err error) { return nil } // Shutdown implements the [Service] interface for Empty. func (Empty) Shutdown(_ context.Context) (err error) { return nil } diff --git a/internal/next/agh/agh.go b/internal/next/agh/agh.go index 52855524..2248bc81 100644 --- a/internal/next/agh/agh.go +++ b/internal/next/agh/agh.go @@ -1,36 +1,9 @@ // Package agh contains common entities and interfaces of AdGuard Home. package agh -import "context" - -// Service is the interface for API servers. -// -// TODO(a.garipov): Consider adding a context to Start. -// -// TODO(a.garipov): Consider adding a Wait method or making an extension -// interface for that. -type Service interface { - // Start starts the service. It does not block. - Start() (err error) - - // Shutdown gracefully stops the service. ctx is used to determine - // a timeout before trying to stop the service less gracefully. - Shutdown(ctx context.Context) (err error) -} - -// type check -var _ Service = EmptyService{} - -// EmptyService is a [Service] that does nothing. -// -// TODO(a.garipov): Remove if unnecessary. -type EmptyService struct{} - -// Start implements the [Service] interface for EmptyService. -func (EmptyService) Start() (err error) { return nil } - -// Shutdown implements the [Service] interface for EmptyService. -func (EmptyService) Shutdown(_ context.Context) (err error) { return nil } +import ( + "github.com/AdguardTeam/golibs/service" +) // ServiceWithConfig is an extension of the [Service] interface for services // that can return their configuration. @@ -38,7 +11,7 @@ func (EmptyService) Shutdown(_ context.Context) (err error) { return nil } // TODO(a.garipov): Consider removing this generic interface if we figure out // how to make it testable in a better way. type ServiceWithConfig[ConfigType any] interface { - Service + service.Interface Config() (c ConfigType) } @@ -51,7 +24,7 @@ var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil) // // TODO(a.garipov): Remove if unnecessary. type EmptyServiceWithConfig[ConfigType any] struct { - EmptyService + service.Empty Conf ConfigType } diff --git a/internal/next/cmd/cmd.go b/internal/next/cmd/cmd.go index 1c118bdb..3bab1396 100644 --- a/internal/next/cmd/cmd.go +++ b/internal/next/cmd/cmd.go @@ -12,11 +12,15 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/version" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/service" ) // Main is the entry point of AdGuard Home. func Main(embeddedFrontend fs.FS) { + ctx := context.Background() + start := time.Now() cmdName := os.Args[0] @@ -26,70 +30,69 @@ func Main(embeddedFrontend fs.FS) { os.Exit(exitCode) } - err = setLog(opts) - check(err) + baseLogger := newBaseLogger(opts) - log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid()) + baseLogger.InfoContext( + ctx, + "starting adguard home", + "version", version.Version(), + "pid", os.Getpid(), + ) if opts.workDir != "" { - log.Info("changing working directory to %q", opts.workDir) + baseLogger.InfoContext(ctx, "changing working directory", "dir", opts.workDir) + err = os.Chdir(opts.workDir) - check(err) + errors.Check(err) } - frontend, err := frontendFromOpts(opts, embeddedFrontend) - check(err) + frontend, err := frontendFromOpts(ctx, baseLogger, opts, embeddedFrontend) + errors.Check(err) + + startCtx, startCancel := context.WithTimeout(ctx, defaultTimeoutStart) + defer startCancel() confMgrConf := &configmgr.Config{ - Frontend: frontend, - WebAddr: opts.webAddr, - Start: start, - FileName: opts.confFile, + BaseLogger: baseLogger, + Logger: baseLogger.With(slogutil.KeyPrefix, "configmgr"), + Frontend: frontend, + WebAddr: opts.webAddr, + Start: start, + FileName: opts.confFile, } - confMgr, err := newConfigMgr(confMgrConf) - check(err) + confMgr, err := configmgr.New(startCtx, confMgrConf) + errors.Check(err) web := confMgr.Web() - err = web.Start() - check(err) + err = web.Start(startCtx) + errors.Check(err) dns := confMgr.DNS() - err = dns.Start() - check(err) + err = dns.Start(startCtx) + errors.Check(err) sigHdlr := newSignalHandler( + baseLogger.With(slogutil.KeyPrefix, service.SignalHandlerPrefix), confMgrConf, opts.pidFile, web, dns, ) - sigHdlr.handle() + os.Exit(sigHdlr.handle(ctx)) } -// defaultTimeout is the timeout used for some operations where another timeout -// hasn't been defined yet. -const defaultTimeout = 5 * time.Second - -// ctxWithDefaultTimeout is a helper function that returns a context with -// timeout set to defaultTimeout. -func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) { - return context.WithTimeout(context.Background(), defaultTimeout) -} +// Default timeouts. +// +// TODO(a.garipov): Make configurable. +const ( + defaultTimeoutStart = 1 * time.Minute + defaultTimeoutShutdown = 5 * time.Second +) // newConfigMgr returns a new configuration manager using defaultTimeout as the // context timeout. -func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) { - ctx, cancel := ctxWithDefaultTimeout() - defer cancel() - +func newConfigMgr(ctx context.Context, c *configmgr.Config) (m *configmgr.Manager, err error) { return configmgr.New(ctx, c) } - -// check is a simple error-checking helper. It must only be used within Main. -func check(err error) { - if err != nil { - panic(err) - } -} diff --git a/internal/next/cmd/log.go b/internal/next/cmd/log.go index 3aa2a0e5..0f25dad1 100644 --- a/internal/next/cmd/log.go +++ b/internal/next/cmd/log.go @@ -1,39 +1,39 @@ package cmd import ( - "fmt" + "io" + "log/slog" "os" - "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) -// syslogServiceName is the name of the AdGuard Home service used for writing -// logs to the system log. -const syslogServiceName = "AdGuardHome" - -// setLog sets up the text logging. -// -// TODO(a.garipov): Add parameters from configuration file. -func setLog(opts *options) (err error) { +// newBaseLogger constructs a base logger based on the command-line options. +// opts must not be nil. +func newBaseLogger(opts *options) (baseLogger *slog.Logger) { + var output io.Writer switch opts.confFile { case "stdout": - log.SetOutput(os.Stdout) + output = os.Stdout case "stderr": - log.SetOutput(os.Stderr) + output = os.Stderr case "syslog": - err = aghos.ConfigureSyslog(syslogServiceName) - if err != nil { - return fmt.Errorf("initializing syslog: %w", err) - } + // TODO(a.garipov): Add a syslog handler to golibs. default: - // TODO(a.garipov): Use the path. + // TODO(a.garipov): Use the path. } + lvl := slog.LevelInfo if opts.verbose { - log.SetLevel(log.DEBUG) - log.Debug("verbose logging enabled") + lvl = slog.LevelDebug } - return nil + return slogutil.New(&slogutil.Config{ + Output: output, + // TODO(a.garipov): Get from config? + Format: slogutil.FormatText, + Level: lvl, + // TODO(a.garipov): Get from config. + AddTimestamp: true, + }) } diff --git a/internal/next/cmd/opt.go b/internal/next/cmd/opt.go index 27d95b92..8e06a492 100644 --- a/internal/next/cmd/opt.go +++ b/internal/next/cmd/opt.go @@ -1,11 +1,13 @@ package cmd import ( + "context" "encoding" "flag" "fmt" "io" "io/fs" + "log/slog" "net/netip" "os" "slices" @@ -14,7 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/configmigrate" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/version" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/osutil" ) // options contains all command-line options for the AdGuardHome(.exe) binary. @@ -372,13 +374,13 @@ func processOptions( ) (exitCode int, needExit bool) { if parseErr != nil { // Assume that usage has already been printed. - return statusArgumentError, true + return osutil.ExitCodeArgumentError, true } if opts.help { usage(cmdName, os.Stdout) - return statusSuccess, true + return osutil.ExitCodeSuccess, true } if opts.version { @@ -388,7 +390,7 @@ func processOptions( fmt.Printf("AdGuard Home %s\n", version.Version()) } - return statusSuccess, true + return osutil.ExitCodeSuccess, true } if opts.checkConfig { @@ -396,21 +398,26 @@ func processOptions( if err != nil { _, _ = io.WriteString(os.Stdout, err.Error()+"\n") - return statusError, true + return osutil.ExitCodeFailure, true } - return statusSuccess, true + return osutil.ExitCodeSuccess, true } return 0, false } // frontendFromOpts returns the frontend to use based on the options. -func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) { +func frontendFromOpts( + ctx context.Context, + logger *slog.Logger, + opts *options, + embeddedFrontend fs.FS, +) (frontend fs.FS, err error) { const frontendSubdir = "build/static" if opts.localFrontend { - log.Info("warning: using local frontend files") + logger.WarnContext(ctx, "using local frontend files") return os.DirFS(frontendSubdir), nil } diff --git a/internal/next/cmd/signal.go b/internal/next/cmd/signal.go index 2454e062..d6aa7dc5 100644 --- a/internal/next/cmd/signal.go +++ b/internal/next/cmd/signal.go @@ -1,18 +1,26 @@ package cmd import ( + "context" + "fmt" + "log/slog" "os" "strconv" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/service" ) // signalHandler processes incoming signals and shuts services down. type signalHandler struct { + // logger is used for logging the operation of the signal handler. + logger *slog.Logger + // confMgrConf contains the configuration parameters for the configuration // manager. confMgrConf *configmgr.Config @@ -24,145 +32,172 @@ type signalHandler struct { pidFile string // services are the services that are shut down before application exiting. - services []agh.Service + services []service.Interface + + // shutdownTimeout is the timeout for the shutdown operation. + shutdownTimeout time.Duration } -// handle processes OS signals. -func (h *signalHandler) handle() { - defer log.OnPanic("signalHandler.handle") +// handle processes OS signals. It blocks until a termination or a +// reconfiguration signal is received, after which it either shuts down all +// services or reconfigures them. ctx is used for logging and serves as the +// base for the shutdown timeout. status is [osutil.ExitCodeSuccess] on success +// and [osutil.ExitCodeFailure] on error. +// +// TODO(a.garipov): Add reconfiguration logic to golibs. +func (h *signalHandler) handle(ctx context.Context) (status osutil.ExitCode) { + defer slogutil.RecoverAndLog(ctx, h.logger) - h.writePID() + h.writePID(ctx) for sig := range h.signal { - log.Info("sighdlr: received signal %q", sig) + h.logger.InfoContext(ctx, "received", "signal", sig) - if aghos.IsReconfigureSignal(sig) { - h.reconfigure() + if osutil.IsReconfigureSignal(sig) { + err := h.reconfigure(ctx) + if err != nil { + h.logger.ErrorContext(ctx, "reconfiguration error", slogutil.KeyError, err) + + return osutil.ExitCodeFailure + } } else if osutil.IsShutdownSignal(sig) { - status := h.shutdown() - h.removePID() + status = h.shutdown(ctx) - log.Info("sighdlr: exiting with status %d", status) + h.removePID(ctx) - os.Exit(status) + return status } } + + // Shouldn't happen, since h.signal is currently never closed. + panic("unexpected close of h.signal") +} + +// writePID writes the PID to the file, if needed. Any errors are reported to +// log. +func (h *signalHandler) writePID(ctx context.Context) { + if h.pidFile == "" { + return + } + + pid := os.Getpid() + data := strconv.AppendInt(nil, int64(pid), 10) + data = append(data, '\n') + + err := aghos.WriteFile(h.pidFile, data, 0o644) + if err != nil { + h.logger.ErrorContext(ctx, "writing pidfile", slogutil.KeyError, err) + + return + } + + h.logger.DebugContext(ctx, "wrote pid", "file", h.pidFile, "pid", pid) } // reconfigure rereads the configuration file and updates and restarts services. -func (h *signalHandler) reconfigure() { - log.Info("sighdlr: reconfiguring adguard home") +func (h *signalHandler) reconfigure(ctx context.Context) (err error) { + h.logger.InfoContext(ctx, "reconfiguring started") - status := h.shutdown() - if status != statusSuccess { - log.Info("sighdlr: reconfiguring: exiting with status %d", status) - - os.Exit(status) + status := h.shutdown(ctx) + if status != osutil.ExitCodeSuccess { + return errors.Error("shutdown failed") } - // TODO(a.garipov): This is a very rough way to do it. Some services can be - // reconfigured without the full shutdown, and the error handling is + // TODO(a.garipov): This is a very rough way to do it. Some services can + // be reconfigured without the full shutdown, and the error handling is // currently not the best. - confMgr, err := newConfigMgr(h.confMgrConf) - check(err) + var errs []error + + ctx, cancel := context.WithTimeout(ctx, defaultTimeoutStart) + defer cancel() + + confMgr, err := newConfigMgr(ctx, h.confMgrConf) + if err != nil { + errs = append(errs, fmt.Errorf("configuration manager: %w", err)) + } web := confMgr.Web() - err = web.Start() - check(err) + err = web.Start(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("starting web: %w", err)) + } dns := confMgr.DNS() - err = dns.Start() - check(err) + err = dns.Start(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("starting dns: %w", err)) + } - h.services = []agh.Service{ + if len(errs) > 0 { + return errors.Join(errs...) + } + + h.services = []service.Interface{ dns, web, } - log.Info("sighdlr: successfully reconfigured adguard home") + h.logger.InfoContext(ctx, "reconfiguring finished") + + return nil } -// Exit status constants. -const ( - statusSuccess = 0 - statusError = 1 - statusArgumentError = 2 -) - // shutdown gracefully shuts down all services. -func (h *signalHandler) shutdown() (status int) { - ctx, cancel := ctxWithDefaultTimeout() +func (h *signalHandler) shutdown(ctx context.Context) (status int) { + ctx, cancel := context.WithTimeout(ctx, h.shutdownTimeout) defer cancel() - status = statusSuccess + status = osutil.ExitCodeSuccess - log.Info("sighdlr: shutting down services") - for i, service := range h.services { - err := service.Shutdown(ctx) + h.logger.InfoContext(ctx, "shutting down") + for i, svc := range h.services { + err := svc.Shutdown(ctx) if err != nil { - log.Error("sighdlr: shutting down service at index %d: %s", i, err) - status = statusError + h.logger.ErrorContext(ctx, "shutting down service", "idx", i, slogutil.KeyError, err) + status = osutil.ExitCodeFailure } } return status } -// newSignalHandler returns a new signalHandler that shuts down svcs. +// newSignalHandler returns a new signalHandler that shuts down svcs. logger +// and confMgrConf must not be nil. func newSignalHandler( + logger *slog.Logger, confMgrConf *configmgr.Config, pidFile string, - svcs ...agh.Service, + svcs ...service.Interface, ) (h *signalHandler) { h = &signalHandler{ - confMgrConf: confMgrConf, - signal: make(chan os.Signal, 1), - pidFile: pidFile, - services: svcs, + logger: logger, + confMgrConf: confMgrConf, + signal: make(chan os.Signal, 1), + pidFile: pidFile, + services: svcs, + shutdownTimeout: defaultTimeoutShutdown, } notifier := osutil.DefaultSignalNotifier{} osutil.NotifyShutdownSignal(notifier, h.signal) - aghos.NotifyReconfigureSignal(h.signal) + osutil.NotifyReconfigureSignal(notifier, h.signal) return h } -// writePID writes the PID to the file, if needed. Any errors are reported to -// log. -func (h *signalHandler) writePID() { - if h.pidFile == "" { - return - } - - // Use 8, since most PIDs will fit. - data := make([]byte, 0, 8) - data = strconv.AppendInt(data, int64(os.Getpid()), 10) - data = append(data, '\n') - - err := aghos.WriteFile(h.pidFile, data, 0o644) - if err != nil { - log.Error("sighdlr: writing pidfile: %s", err) - - return - } - - log.Debug("sighdlr: wrote pid to %q", h.pidFile) -} - // removePID removes the PID file, if any. -func (h *signalHandler) removePID() { +func (h *signalHandler) removePID(ctx context.Context) { if h.pidFile == "" { return } err := os.Remove(h.pidFile) if err != nil { - log.Error("sighdlr: removing pidfile: %s", err) + h.logger.ErrorContext(ctx, "removing pidfile", slogutil.KeyError, err) return } - log.Debug("sighdlr: removed pid at %q", h.pidFile) + h.logger.DebugContext(ctx, "removed pidfile", "file", h.pidFile) } diff --git a/internal/next/configmgr/config.go b/internal/next/configmgr/config.go index 5d67a372..7b47b147 100644 --- a/internal/next/configmgr/config.go +++ b/internal/next/configmgr/config.go @@ -4,12 +4,11 @@ import ( "fmt" "net/netip" + "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/timeutil" ) -// Configuration Structures - // config is the top-level on-disk configuration structure. type config struct { DNS *dnsConfig `yaml:"dns"` @@ -19,35 +18,33 @@ type config struct { SchemaVersion int `yaml:"schema_version"` } -const errNoConf errors.Error = "configuration not found" +// type check +var _ validator = (*config)(nil) -// validate returns an error if the configuration structure is invalid. +// validate implements the [validator] interface for *config. func (c *config) validate() (err error) { if c == nil { - return errNoConf + return errors.ErrNoValue } // TODO(a.garipov): Add more validations. // Keep this in the same order as the fields in the config. - validators := []struct { - validate func() (err error) - name string - }{{ - validate: c.DNS.validate, - name: "dns", + validators := container.KeyValues[string, validator]{{ + Key: "dns", + Value: c.DNS, }, { - validate: c.HTTP.validate, - name: "http", + Key: "http", + Value: c.HTTP, }, { - validate: c.Log.validate, - name: "log", + Key: "log", + Value: c.Log, }} - for _, v := range validators { - err = v.validate() + for _, kv := range validators { + err = kv.Value.validate() if err != nil { - return fmt.Errorf("%s: %w", v.name, err) + return fmt.Errorf("%s: %w", kv.Key, err) } } @@ -65,16 +62,19 @@ type dnsConfig struct { UseDNS64 bool `yaml:"use_dns64"` } -// validate returns an error if the DNS configuration structure is invalid. +// type check +var _ validator = (*dnsConfig)(nil) + +// validate implements the [validator] interface for *dnsConfig. // // TODO(a.garipov): Add more validations. func (c *dnsConfig) validate() (err error) { // TODO(a.garipov): Add more validations. switch { case c == nil: - return errNoConf + return errors.ErrNoValue case c.UpstreamTimeout.Duration <= 0: - return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout) + return newErrNotPositive("upstream_timeout", c.UpstreamTimeout) default: return nil } @@ -91,15 +91,18 @@ type httpConfig struct { ForceHTTPS bool `yaml:"force_https"` } -// validate returns an error if the HTTP configuration structure is invalid. +// type check +var _ validator = (*httpConfig)(nil) + +// validate implements the [validator] interface for *httpConfig. // // TODO(a.garipov): Add more validations. func (c *httpConfig) validate() (err error) { switch { case c == nil: - return errNoConf + return errors.ErrNoValue case c.Timeout.Duration <= 0: - return newMustBePositiveError("timeout", c.Timeout) + return newErrNotPositive("timeout", c.Timeout) default: return c.Pprof.validate() } @@ -111,10 +114,13 @@ type httpPprofConfig struct { Enabled bool `yaml:"enabled"` } -// validate returns an error if the pprof configuration structure is invalid. +// type check +var _ validator = (*httpPprofConfig)(nil) + +// validate implements the [validator] interface for *httpPprofConfig. func (c *httpPprofConfig) validate() (err error) { if c == nil { - return errNoConf + return errors.ErrNoValue } return nil @@ -126,12 +132,15 @@ type logConfig struct { Verbose bool `yaml:"verbose"` } -// validate returns an error if the HTTP configuration structure is invalid. +// type check +var _ validator = (*logConfig)(nil) + +// validate implements the [validator] interface for *logConfig. // // TODO(a.garipov): Add more validations. func (c *logConfig) validate() (err error) { if c == nil { - return errNoConf + return errors.ErrNoValue } return nil diff --git a/internal/next/configmgr/configmgr.go b/internal/next/configmgr/configmgr.go index a22b5bbb..1f9bc8de 100644 --- a/internal/next/configmgr/configmgr.go +++ b/internal/next/configmgr/configmgr.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io/fs" + "log/slog" "net/netip" "os" "slices" @@ -19,18 +20,22 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" "gopkg.in/yaml.v3" ) -// Configuration Manager - // Manager handles full and partial changes in the configuration, persisting // them to disk if necessary. // // TODO(a.garipov): Support missing configs and default values. type Manager struct { + // baseLogger is used to create loggers for other entities. + baseLogger *slog.Logger + + // logger is used for logging the operation of the configuration manager. + logger *slog.Logger + // updMu makes sure that at most one reconfiguration is performed at a time. // updMu protects all fields below. updMu *sync.RWMutex @@ -57,12 +62,24 @@ func Validate(fileName string) (err error) { return err } - // Don't wrap the error, because it's informative enough as is. - return conf.validate() + err = conf.validate() + if err != nil { + return fmt.Errorf("validating config: %w", err) + } + + return nil } // Config contains the configuration parameters for the configuration manager. type Config struct { + // BaseLogger is used to create loggers for other entities. It must not be + // nil. + BaseLogger *slog.Logger + + // Logger is used for logging the operation of the configuration manager. + // It must not be nil. + Logger *slog.Logger + // Frontend is the filesystem with the frontend files. Frontend fs.FS @@ -93,9 +110,11 @@ func New(ctx context.Context, c *Config) (m *Manager, err error) { } m = &Manager{ - updMu: &sync.RWMutex{}, - current: conf, - fileName: c.FileName, + baseLogger: c.BaseLogger, + logger: c.Logger, + updMu: &sync.RWMutex{}, + current: conf, + fileName: c.FileName, } err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start) @@ -137,6 +156,7 @@ func (m *Manager) assemble( start time.Time, ) (err error) { dnsConf := &dnssvc.Config{ + Logger: m.baseLogger.With(slogutil.KeyPrefix, "dnssvc"), Addresses: conf.DNS.Addresses, BootstrapServers: conf.DNS.BootstrapDNS, UpstreamServers: conf.DNS.UpstreamDNS, @@ -151,6 +171,7 @@ func (m *Manager) assemble( } webSvcConf := &websvc.Config{ + Logger: m.baseLogger.With(slogutil.KeyPrefix, "websvc"), Pprof: &websvc.PprofConfig{ Port: conf.HTTP.Pprof.Port, Enabled: conf.HTTP.Pprof.Enabled, @@ -176,7 +197,7 @@ func (m *Manager) assemble( } // write writes the current configuration to disk. -func (m *Manager) write() (err error) { +func (m *Manager) write(ctx context.Context) (err error) { b, err := yaml.Marshal(m.current) if err != nil { return fmt.Errorf("encoding: %w", err) @@ -187,7 +208,7 @@ func (m *Manager) write() (err error) { return fmt.Errorf("writing: %w", err) } - log.Info("configmgr: written to %q", m.fileName) + m.logger.InfoContext(ctx, "config file written", "path", m.fileName) return nil } @@ -216,7 +237,7 @@ func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) { m.updateCurrentDNS(c) - return m.write() + return m.write(ctx) } // updateDNS recreates the DNS service. m.updMu is expected to be locked. @@ -270,7 +291,7 @@ func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) { m.updateCurrentWeb(c) - return m.write() + return m.write(ctx) } // updateWeb recreates the web service. m.upd is expected to be locked. diff --git a/internal/next/configmgr/error.go b/internal/next/configmgr/error.go index b4ffb92b..4b737197 100644 --- a/internal/next/configmgr/error.go +++ b/internal/next/configmgr/error.go @@ -3,25 +3,29 @@ package configmgr import ( "fmt" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/timeutil" "golang.org/x/exp/constraints" ) +// validator is the interface for configuration entities that can validate +// themselves. +type validator interface { + // validate returns an error if the entity isn't valid. + validate() (err error) +} + // numberOrDuration is the constraint for integer types along with // timeutil.Duration. type numberOrDuration interface { constraints.Integer | timeutil.Duration } -// newMustBePositiveError returns an error about the value that must be positive -// but isn't. prop is the name of the property to mention in the error message. +// newErrNotPositive returns an error about the value that must be positive but +// isn't. prop is the name of the property to mention in the error message. // // TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS // as well. -func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) { - if s, ok := any(v).(fmt.Stringer); ok { - return fmt.Errorf("%s must be positive, got %s", prop, s) - } - - return fmt.Errorf("%s must be positive, got %d", prop, v) +func newErrNotPositive[T numberOrDuration](prop string, v T) (err error) { + return fmt.Errorf("%s: %w, got %v", prop, errors.ErrNotPositive, v) } diff --git a/internal/next/dnssvc/config.go b/internal/next/dnssvc/config.go index 57818c20..e4e882be 100644 --- a/internal/next/dnssvc/config.go +++ b/internal/next/dnssvc/config.go @@ -1,6 +1,7 @@ package dnssvc import ( + "log/slog" "net/netip" "time" ) @@ -9,6 +10,10 @@ import ( // // TODO(a.garipov): Add timeout for incoming requests. type Config struct { + // Logger is used for logging the operation of the web API service. It must + // not be nil. + Logger *slog.Logger + // Addresses are the addresses on which to serve plain DNS queries. Addresses []netip.AddrPort diff --git a/internal/next/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go index 345af7bc..9e3b5b35 100644 --- a/internal/next/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -7,6 +7,7 @@ package dnssvc import ( "context" "fmt" + "log/slog" "net" "net/netip" "sync/atomic" @@ -28,6 +29,7 @@ import ( // TODO(a.garipov): Consider saving a [*proxy.Config] instance for those // fields that are only used in [New] and [Service.Config]. type Service struct { + logger *slog.Logger proxy *proxy.Proxy bootstraps []string bootstrapResolvers []*upstream.UpstreamResolver @@ -48,6 +50,7 @@ func New(c *Config) (svc *Service, err error) { } svc = &Service{ + logger: c.Logger, bootstraps: c.BootstrapServers, upstreams: c.UpstreamServers, dns64Prefixes: c.DNS64Prefixes, @@ -68,6 +71,7 @@ func New(c *Config) (svc *Service, err error) { svc.bootstrapResolvers = resolvers svc.proxy, err = proxy.New(&proxy.Config{ + Logger: svc.logger, UDPListenAddr: udpAddrs(c.Addresses), TCPListenAddr: tcpAddrs(c.Addresses), UpstreamConfig: &proxy.UpstreamConfig{ @@ -153,12 +157,12 @@ func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) { } // type check -var _ agh.Service = (*Service)(nil) +var _ agh.ServiceWithConfig[*Config] = (*Service)(nil) // Start implements the [agh.Service] interface for *Service. svc may be nil. // After Start exits, all DNS servers have tried to start, but there is no // guarantee that they did. Errors from the servers are written to the log. -func (svc *Service) Start() (err error) { +func (svc *Service) Start(ctx context.Context) (err error) { if svc == nil { return nil } @@ -170,7 +174,7 @@ func (svc *Service) Start() (err error) { svc.running.Store(err == nil) }() - return svc.proxy.Start(context.Background()) + return svc.proxy.Start(ctx) } // Shutdown implements the [agh.Service] interface for *Service. svc may be @@ -215,6 +219,7 @@ func (svc *Service) Config() (c *Config) { } c = &Config{ + Logger: svc.logger, Addresses: addrs, BootstrapServers: svc.bootstraps, UpstreamServers: svc.upstreams, diff --git a/internal/next/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go index 2a46d956..c8a438eb 100644 --- a/internal/next/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -6,16 +6,13 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} - // testTimeout is the common timeout for tests. const testTimeout = 1 * time.Second @@ -59,6 +56,7 @@ func TestService(t *testing.T) { _, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout) c := &dnssvc.Config{ + Logger: slogutil.NewDiscardLogger(), Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)}, BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()}, UpstreamServers: []string{upstreamAddr}, @@ -71,7 +69,7 @@ func TestService(t *testing.T) { svc, err := dnssvc.New(c) require.NoError(t, err) - err = svc.Start() + err = svc.Start(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) gotConf := svc.Config() diff --git a/internal/next/websvc/config.go b/internal/next/websvc/config.go index 36a145c5..6e81afa2 100644 --- a/internal/next/websvc/config.go +++ b/internal/next/websvc/config.go @@ -3,12 +3,17 @@ package websvc import ( "crypto/tls" "io/fs" + "log/slog" "net/netip" "time" ) // Config is the AdGuard Home web service configuration structure. type Config struct { + // Logger is used for logging the operation of the web API service. It must + // not be nil. + Logger *slog.Logger + // Pprof is the configuration for the pprof debug API. It must not be nil. Pprof *PprofConfig @@ -60,17 +65,20 @@ type PprofConfig struct { // finished. func (svc *Service) Config() (c *Config) { c = &Config{ + Logger: svc.logger, Pprof: &PprofConfig{ Port: svc.pprofPort, Enabled: svc.pprof != nil, }, ConfigManager: svc.confMgr, + Frontend: svc.frontend, TLS: svc.tls, // Leave Addresses and SecureAddresses empty and get the actual // addresses that include the :0 ones later. - Start: svc.start, - Timeout: svc.timeout, - ForceHTTPS: svc.forceHTTPS, + Start: svc.start, + OverrideAddress: svc.overrideAddr, + Timeout: svc.timeout, + ForceHTTPS: svc.forceHTTPS, } c.Addresses, c.SecureAddresses = svc.addrs() diff --git a/internal/next/websvc/dns.go b/internal/next/websvc/dns.go index 39f05d22..9c2a222f 100644 --- a/internal/next/websvc/dns.go +++ b/internal/next/websvc/dns.go @@ -11,8 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" ) -// DNS Settings Handlers - // ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns // HTTP API. type ReqPatchSettingsDNS struct { @@ -60,6 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques } newConf := &dnssvc.Config{ + Logger: svc.logger, Addresses: req.Addresses, BootstrapServers: req.BootstrapServers, UpstreamServers: req.UpstreamServers, @@ -78,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques } newSvc := svc.confMgr.DNS() - err = newSvc.Start() + err = newSvc.Start(ctx) if err != nil { aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err)) diff --git a/internal/next/websvc/dns_test.go b/internal/next/websvc/dns_test.go index d7e58eb2..bb546778 100644 --- a/internal/next/websvc/dns_test.go +++ b/internal/next/websvc/dns_test.go @@ -35,7 +35,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) { confMgr := newConfigManager() confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { return &aghtest.ServiceWithConfig[*dnssvc.Config]{ - OnStart: func() (err error) { + OnStart: func(_ context.Context) (err error) { started.Store(true) return nil @@ -52,7 +52,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) { u := &url.URL{ Scheme: urlutil.SchemeHTTP, Host: addr.String(), - Path: websvc.PathV1SettingsDNS, + Path: websvc.PathPatternV1SettingsDNS, } req := jobj{ diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go index db32372d..3fe8bce7 100644 --- a/internal/next/websvc/http.go +++ b/internal/next/websvc/http.go @@ -10,11 +10,9 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) -// HTTP Settings Handlers - // ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http // HTTP API. type ReqPatchSettingsHTTP struct { @@ -53,6 +51,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque } newConf := &Config{ + Logger: svc.logger, Pprof: &PprofConfig{ Port: svc.pprofPort, Enabled: svc.pprof != nil, @@ -89,13 +88,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque // relaunch updates the web service in the configuration manager and starts it. // It is intended to be used as a goroutine. func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) { - defer log.OnPanic("websvc: relaunching") + defer slogutil.RecoverAndLog(ctx, svc.logger) defer cancel() err := svc.confMgr.UpdateWeb(ctx, newConf) if err != nil { - log.Error("websvc: updating web: %s", err) + svc.logger.ErrorContext(ctx, "updating web", slogutil.KeyError, err) return } @@ -106,18 +105,18 @@ func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, new var newSvc agh.ServiceWithConfig[*Config] for newSvc = svc.confMgr.Web(); newSvc == svc; { if time.Since(updStart) >= maxUpdDur { - log.Error("websvc: failed to update svc after %s", maxUpdDur) + svc.logger.ErrorContext(ctx, "failed to update service on time", "duration", maxUpdDur) return } - log.Debug("websvc: waiting for new websvc to be configured") + svc.logger.DebugContext(ctx, "waiting for new service") time.Sleep(100 * time.Millisecond) } - err = newSvc.Start() + err = newSvc.Start(ctx) if err != nil { - log.Error("websvc: new svc failed to start with error: %s", err) + svc.logger.ErrorContext(ctx, "new service failed", slogutil.KeyError, err) } } diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go index 99569b62..297754f4 100644 --- a/internal/next/websvc/http_test.go +++ b/internal/next/websvc/http_test.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -27,14 +28,15 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { } svc, err := websvc.New(&websvc.Config{ + Logger: slogutil.NewDiscardLogger(), Pprof: &websvc.PprofConfig{ Enabled: false, }, TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, }, - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, - SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, Timeout: 5 * time.Second, ForceHTTPS: true, }) @@ -48,7 +50,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { u := &url.URL{ Scheme: urlutil.SchemeHTTP, Host: addr.String(), - Path: websvc.PathV1SettingsHTTP, + Path: websvc.PathPatternV1SettingsHTTP, } req := jobj{ diff --git a/internal/next/websvc/middleware.go b/internal/next/websvc/middleware.go index 8dc66b34..e90bb96b 100644 --- a/internal/next/websvc/middleware.go +++ b/internal/next/websvc/middleware.go @@ -2,15 +2,11 @@ package websvc import ( "net/http" - "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/httphdr" - "github.com/AdguardTeam/golibs/log" ) -// Middlewares - // jsonMw sets the content type of the response to application/json. func jsonMw(h http.Handler) (wrapped http.HandlerFunc) { f := func(w http.ResponseWriter, r *http.Request) { @@ -21,18 +17,3 @@ func jsonMw(h http.Handler) (wrapped http.HandlerFunc) { return http.HandlerFunc(f) } - -// logMw logs the queries with level debug. -func logMw(h http.Handler) (wrapped http.HandlerFunc) { - f := func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - m, u := r.Method, r.RequestURI - - log.Debug("websvc: %s %s started", m, u) - defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }() - - h.ServeHTTP(w, r) - } - - return http.HandlerFunc(f) -} diff --git a/internal/next/websvc/path.go b/internal/next/websvc/path.go deleted file mode 100644 index 95be8204..00000000 --- a/internal/next/websvc/path.go +++ /dev/null @@ -1,14 +0,0 @@ -package websvc - -// Path constants -const ( - PathRoot = "/" - PathFrontend = "/*filepath" - - PathHealthCheck = "/health-check" - - PathV1SettingsAll = "/api/v1/settings/all" - PathV1SettingsDNS = "/api/v1/settings/dns" - PathV1SettingsHTTP = "/api/v1/settings/http" - PathV1SystemInfo = "/api/v1/system/info" -) diff --git a/internal/next/websvc/route.go b/internal/next/websvc/route.go new file mode 100644 index 00000000..e2e5b06f --- /dev/null +++ b/internal/next/websvc/route.go @@ -0,0 +1,73 @@ +package websvc + +import ( + "log/slog" + "net/http" + + "github.com/AdguardTeam/golibs/netutil/httputil" +) + +// Path pattern constants. +const ( + PathPatternFrontend = "/" + PathPatternHealthCheck = "/health-check" + PathPatternV1SettingsAll = "/api/v1/settings/all" + PathPatternV1SettingsDNS = "/api/v1/settings/dns" + PathPatternV1SettingsHTTP = "/api/v1/settings/http" + PathPatternV1SystemInfo = "/api/v1/system/info" +) + +// Route pattern constants. +const ( + routePatternFrontend = http.MethodGet + " " + PathPatternFrontend + routePatternGetV1SettingsAll = http.MethodGet + " " + PathPatternV1SettingsAll + routePatternGetV1SystemInfo = http.MethodGet + " " + PathPatternV1SystemInfo + routePatternHealthCheck = http.MethodGet + " " + PathPatternHealthCheck + routePatternPatchV1SettingsDNS = http.MethodPatch + " " + PathPatternV1SettingsDNS + routePatternPatchV1SettingsHTTP = http.MethodPatch + " " + PathPatternV1SettingsHTTP +) + +// route registers all necessary handlers in mux. +func (svc *Service) route(mux *http.ServeMux) { + routes := []struct { + handler http.Handler + pattern string + isJSON bool + }{{ + handler: httputil.HealthCheckHandler, + pattern: routePatternHealthCheck, + isJSON: false, + }, { + handler: http.FileServer(http.FS(svc.frontend)), + pattern: routePatternFrontend, + isJSON: false, + }, { + handler: http.HandlerFunc(svc.handleGetSettingsAll), + pattern: routePatternGetV1SettingsAll, + isJSON: true, + }, { + handler: http.HandlerFunc(svc.handlePatchSettingsDNS), + pattern: routePatternPatchV1SettingsDNS, + isJSON: true, + }, { + handler: http.HandlerFunc(svc.handlePatchSettingsHTTP), + pattern: routePatternPatchV1SettingsHTTP, + isJSON: true, + }, { + handler: http.HandlerFunc(svc.handleGetV1SystemInfo), + pattern: routePatternGetV1SystemInfo, + isJSON: true, + }} + + logMw := httputil.NewLogMiddleware(svc.logger, slog.LevelDebug) + for _, r := range routes { + var hdlr http.Handler + if r.isJSON { + hdlr = jsonMw(r.handler) + } else { + hdlr = r.handler + } + + mux.Handle(r.pattern, logMw.Wrap(hdlr)) + } +} diff --git a/internal/next/websvc/server.go b/internal/next/websvc/server.go new file mode 100644 index 00000000..f1299d04 --- /dev/null +++ b/internal/next/websvc/server.go @@ -0,0 +1,156 @@ +package websvc + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/netutil/urlutil" +) + +// server contains an *http.Server as well as entities and data associated with +// it. +// +// TODO(a.garipov): Join with similar structs in other projects and move to +// golibs/netutil/httputil. +// +// TODO(a.garipov): Once the above standardization is complete, consider +// merging debugsvc and websvc into a single httpsvc. +type server struct { + // mu protects http, logger, tcpListener, and url. + mu *sync.Mutex + http *http.Server + logger *slog.Logger + tcpListener *net.TCPListener + url *url.URL + + tlsConf *tls.Config + initialAddr netip.AddrPort +} + +// loggerKeyServer is the key used by [server] to identify itself. +const loggerKeyServer = "server" + +// newServer returns a *server that is ready to serve HTTP queries. The TCP +// listener is not started. handler must not be nil. +func newServer( + baseLogger *slog.Logger, + initialAddr netip.AddrPort, + tlsConf *tls.Config, + handler http.Handler, + timeout time.Duration, +) (s *server) { + u := &url.URL{ + Scheme: urlutil.SchemeHTTP, + Host: initialAddr.String(), + } + + if tlsConf != nil { + u.Scheme = urlutil.SchemeHTTPS + } + + logger := baseLogger.With(loggerKeyServer, u) + + return &server{ + mu: &sync.Mutex{}, + http: &http.Server{ + Handler: handler, + ReadTimeout: timeout, + ReadHeaderTimeout: timeout, + WriteTimeout: timeout, + IdleTimeout: timeout, + ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), + }, + logger: logger, + url: u, + + tlsConf: tlsConf, + initialAddr: initialAddr, + } +} + +// localAddr returns the local address of the server if the server has started +// listening; otherwise, it returns nil. +func (s *server) localAddr() (addr net.Addr) { + s.mu.Lock() + defer s.mu.Unlock() + + if l := s.tcpListener; l != nil { + return l.Addr() + } + + return nil +} + +// serve starts s. baseLogger is used as a base logger for s. If s fails to +// serve with anything other than [http.ErrServerClosed], it causes an unhandled +// panic. It is intended to be used as a goroutine. +// +// TODO(a.garipov): Improve error handling. +func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) { + l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr)) + if err != nil { + s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err) + + panic(fmt.Errorf("websvc: listening tcp: %w", err)) + } + + func() { + s.mu.Lock() + defer s.mu.Unlock() + + s.tcpListener = l + + // Reassign the address in case the port was zero. + s.url.Host = l.Addr().String() + s.logger = baseLogger.With(loggerKeyServer, s.url) + s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError) + }() + + s.logger.InfoContext(ctx, "starting") + defer s.logger.InfoContext(ctx, "started") + + err = s.http.Serve(l) + if err == nil || errors.Is(err, http.ErrServerClosed) { + return + } + + s.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err) + + panic(fmt.Errorf("websvc: serving: %w", err)) +} + +// shutdown shuts s down. +func (s *server) shutdown(ctx context.Context) (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + var errs []error + err = s.http.Shutdown(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err)) + } + + // Close the listener separately, as it might not have been closed if the + // context has been canceled. + // + // NOTE: The listener could remain uninitialized if [net.ListenTCP] failed + // in [s.serve]. + if l := s.tcpListener; l != nil { + err = l.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err)) + } + } + + return errors.Join(errs...) +} diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go index e30002de..b2003556 100644 --- a/internal/next/websvc/settings_test.go +++ b/internal/next/websvc/settings_test.go @@ -1,7 +1,6 @@ package websvc_test import ( - "crypto/tls" "encoding/json" "net/http" "net/netip" @@ -13,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -29,16 +29,10 @@ func TestService_HandleGetSettingsAll(t *testing.T) { BootstrapPreferIPv6: true, } - wantWeb := &websvc.HTTPAPIHTTPSettings{ - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, - SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, - Timeout: aghhttp.JSONDuration(5 * time.Second), - ForceHTTPS: true, - } - confMgr := newConfigManager() confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { c, err := dnssvc.New(&dnssvc.Config{ + Logger: slogutil.NewDiscardLogger(), Addresses: wantDNS.Addresses, UpstreamServers: wantDNS.UpstreamServers, BootstrapServers: wantDNS.BootstrapServers, @@ -50,34 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) { return c } - svc, err := websvc.New(&websvc.Config{ - Pprof: &websvc.PprofConfig{ - Enabled: false, - }, - TLS: &tls.Config{ - Certificates: []tls.Certificate{{}}, - }, - Addresses: wantWeb.Addresses, - SecureAddresses: wantWeb.SecureAddresses, - Timeout: time.Duration(wantWeb.Timeout), - ForceHTTPS: true, - }) - require.NoError(t, err) + svc, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: urlutil.SchemeHTTP, + Host: addr.String(), + Path: websvc.PathPatternV1SettingsAll, + } confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { return svc } - _, addr := newTestServer(t, confMgr) - u := &url.URL{ - Scheme: urlutil.SchemeHTTP, - Host: addr.String(), - Path: websvc.PathV1SettingsAll, + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{addr}, + SecureAddresses: nil, + Timeout: aghhttp.JSONDuration(testTimeout), + ForceHTTPS: false, } body := httpGet(t, u, http.StatusOK) resp := &websvc.RespGetV1SettingsAll{} - err = json.Unmarshal(body, resp) + err := json.Unmarshal(body, resp) require.NoError(t, err) assert.Equal(t, wantDNS, resp.DNS) diff --git a/internal/next/websvc/system_test.go b/internal/next/websvc/system_test.go index 44f9e529..a021886c 100644 --- a/internal/next/websvc/system_test.go +++ b/internal/next/websvc/system_test.go @@ -20,7 +20,7 @@ func TestService_handleGetV1SystemInfo(t *testing.T) { u := &url.URL{ Scheme: urlutil.SchemeHTTP, Host: addr.String(), - Path: websvc.PathV1SystemInfo, + Path: websvc.PathPatternV1SystemInfo, } body := httpGet(t, u, http.StatusOK) diff --git a/internal/next/websvc/waitlistener.go b/internal/next/websvc/waitlistener.go deleted file mode 100644 index 8ab56269..00000000 --- a/internal/next/websvc/waitlistener.go +++ /dev/null @@ -1,31 +0,0 @@ -package websvc - -import ( - "net" - "sync" -) - -// Wait Listener - -// waitListener is a wrapper around a listener that also calls wg.Done() on the -// first call to Accept. It is useful in situations where it is important to -// catch the precise moment of the first call to Accept, for example when -// starting an HTTP server. -// -// TODO(a.garipov): Move to aghnet? -type waitListener struct { - net.Listener - - firstAcceptWG *sync.WaitGroup - firstAcceptOnce sync.Once -} - -// type check -var _ net.Listener = (*waitListener)(nil) - -// Accept implements the [net.Listener] interface for *waitListener. -func (l *waitListener) Accept() (conn net.Conn, err error) { - l.firstAcceptOnce.Do(l.firstAcceptWG.Done) - - return l.Listener.Accept() -} diff --git a/internal/next/websvc/waitlistener_internal_test.go b/internal/next/websvc/waitlistener_internal_test.go deleted file mode 100644 index 089c0531..00000000 --- a/internal/next/websvc/waitlistener_internal_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package websvc - -import ( - "net" - "sync" - "sync/atomic" - "testing" - - "github.com/AdguardTeam/golibs/testutil/fakenet" - "github.com/stretchr/testify/assert" -) - -func TestWaitListener_Accept(t *testing.T) { - var accepted atomic.Bool - var l net.Listener = &fakenet.Listener{ - OnAccept: func() (conn net.Conn, err error) { - accepted.Store(true) - - return nil, nil - }, - OnAddr: func() (addr net.Addr) { panic("not implemented") }, - OnClose: func() (err error) { panic("not implemented") }, - } - - wg := &sync.WaitGroup{} - wg.Add(1) - - go func() { - var wrapper net.Listener = &waitListener{ - Listener: l, - firstAcceptWG: wg, - } - - _, _ = wrapper.Accept() - }() - - wg.Wait() - - assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10) -} diff --git a/internal/next/websvc/websvc.go b/internal/next/websvc/websvc.go index 31dbbb65..189d231e 100644 --- a/internal/next/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -10,22 +10,18 @@ import ( "context" "crypto/tls" "fmt" - "io" "io/fs" - "net" + "log/slog" "net/http" "net/netip" "runtime" - "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/mathutil" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil/httputil" - httptreemux "github.com/dimfeld/httptreemux/v5" ) // ConfigManager is the configuration manager interface. @@ -40,13 +36,14 @@ type ConfigManager interface { // Service is the AdGuard Home web service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { + logger *slog.Logger confMgr ConfigManager frontend fs.FS tls *tls.Config - pprof *http.Server + pprof *server start time.Time overrideAddr netip.AddrPort - servers []*http.Server + servers []*server timeout time.Duration pprofPort uint16 forceHTTPS bool @@ -64,6 +61,7 @@ func New(c *Config) (svc *Service, err error) { } svc = &Service{ + logger: c.Logger, confMgr: c.ConfigManager, frontend: c.Frontend, tls: c.TLS, @@ -73,17 +71,18 @@ func New(c *Config) (svc *Service, err error) { forceHTTPS: c.ForceHTTPS, } - mux := newMux(svc) + mux := http.NewServeMux() + svc.route(mux) if svc.overrideAddr != (netip.AddrPort{}) { - svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)} + svc.servers = []*server{newServer(svc.logger, svc.overrideAddr, nil, mux, c.Timeout)} } else { for _, a := range c.Addresses { - svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout)) + svc.servers = append(svc.servers, newServer(svc.logger, a, nil, mux, c.Timeout)) } for _, a := range c.SecureAddresses { - svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout)) + svc.servers = append(svc.servers, newServer(svc.logger, a, c.TLS, mux, c.Timeout)) } } @@ -112,96 +111,7 @@ func (svc *Service) setupPprof(c *PprofConfig) { svc.pprofPort = c.Port addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port) - // TODO(a.garipov): Consider making pprof timeout configurable. - svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute) -} - -// newSrv returns a new *http.Server with the given parameters. -func newSrv( - addr netip.AddrPort, - tlsConf *tls.Config, - h http.Handler, - timeout time.Duration, -) (srv *http.Server) { - addrStr := addr.String() - srv = &http.Server{ - Addr: addrStr, - Handler: h, - TLSConfig: tlsConf, - ReadTimeout: timeout, - WriteTimeout: timeout, - IdleTimeout: timeout, - ReadHeaderTimeout: timeout, - } - - if tlsConf == nil { - srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR) - } else { - srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR) - } - - return srv -} - -// newMux returns a new HTTP request multiplexer for the AdGuard Home web -// service. -func newMux(svc *Service) (mux *httptreemux.ContextMux) { - mux = httptreemux.NewContextMux() - - routes := []struct { - handler http.HandlerFunc - method string - pattern string - isJSON bool - }{{ - handler: svc.handleGetHealthCheck, - method: http.MethodGet, - pattern: PathHealthCheck, - isJSON: false, - }, { - handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP, - method: http.MethodGet, - pattern: PathFrontend, - isJSON: false, - }, { - handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP, - method: http.MethodGet, - pattern: PathRoot, - isJSON: false, - }, { - handler: svc.handleGetSettingsAll, - method: http.MethodGet, - pattern: PathV1SettingsAll, - isJSON: true, - }, { - handler: svc.handlePatchSettingsDNS, - method: http.MethodPatch, - pattern: PathV1SettingsDNS, - isJSON: true, - }, { - handler: svc.handlePatchSettingsHTTP, - method: http.MethodPatch, - pattern: PathV1SettingsHTTP, - isJSON: true, - }, { - handler: svc.handleGetV1SystemInfo, - method: http.MethodGet, - pattern: PathV1SystemInfo, - isJSON: true, - }} - - for _, r := range routes { - var hdlr http.Handler - if r.isJSON { - hdlr = jsonMw(r.handler) - } else { - hdlr = r.handler - } - - mux.Handle(r.method, r.pattern, logMw(hdlr)) - } - - return mux + svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute) } // addrs returns all addresses on which this server serves the HTTP API. addrs @@ -214,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { } for _, srv := range svc.servers { - // Use MustParseAddrPort, since no errors should technically happen - // here, because all servers must have a valid address. - addrPort := netip.MustParseAddrPort(srv.Addr) + addrPort := netutil.NetAddrToAddrPort(srv.localAddr()) + if addrPort == (netip.AddrPort{}) { + continue + } - // [srv.Serve] will set TLSConfig to an almost empty value, so, instead - // of relying only on the nilness of TLSConfig, check the length of the - // certificates field as well. - if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { + if srv.tlsConf == nil { addrs = append(addrs, addrPort) } else { secureAddrs = append(secureAddrs, addrPort) @@ -231,74 +139,60 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { return addrs, secureAddrs } -// handleGetHealthCheck is the handler for the GET /health-check HTTP API. -func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, "OK") -} - // type check -var _ agh.Service = (*Service)(nil) +var _ agh.ServiceWithConfig[*Config] = (*Service)(nil) // Start implements the [agh.Service] interface for *Service. svc may be nil. // After Start exits, all HTTP servers have tried to start, possibly failing and // writing error messages to the log. -func (svc *Service) Start() (err error) { +// +// TODO(a.garipov): Use the context for cancelation as well. +func (svc *Service) Start(ctx context.Context) (err error) { if svc == nil { return nil } - pprofEnabled := svc.pprof != nil - srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled) + svc.logger.InfoContext(ctx, "starting") + defer svc.logger.InfoContext(ctx, "started") - wg := &sync.WaitGroup{} - wg.Add(srvNum) for _, srv := range svc.servers { - go serve(srv, wg) + go srv.serve(ctx, svc.logger) } - if pprofEnabled { - go serve(svc.pprof, wg) + if svc.pprof != nil { + go svc.pprof.serve(ctx, svc.logger) } - wg.Wait() + return svc.wait(ctx) +} + +// wait waits until either the context is canceled or all servers have started. +func (svc *Service) wait(ctx context.Context) (err error) { + for !svc.serversHaveStarted() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Wait and let the other goroutines do their job. + runtime.Gosched() + } + } return nil } -// serve starts and runs srv and writes all errors into its log. -func serve(srv *http.Server, wg *sync.WaitGroup) { - addr := srv.Addr - defer log.OnPanic(addr) - - var proto string - var l net.Listener - var err error - if srv.TLSConfig == nil { - proto = "http" - l, err = net.Listen("tcp", addr) - } else { - proto = "https" - l, err = tls.Listen("tcp", addr, srv.TLSConfig) - } - if err != nil { - srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err) +// serversHaveStarted returns true if all servers have started serving. +func (svc *Service) serversHaveStarted() (started bool) { + started = len(svc.servers) != 0 + for _, srv := range svc.servers { + started = started && srv.localAddr() != nil } - // Update the server's address in case the address had the port zero, which - // would mean that a random available port was automatically chosen. - srv.Addr = l.Addr().String() - - log.Info("websvc: starting srv %s://%s", proto, srv.Addr) - - l = &waitListener{ - Listener: l, - firstAcceptWG: wg, + if svc.pprof != nil { + started = started && svc.pprof.localAddr() != nil } - err = srv.Serve(l) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - srv.ErrorLog.Printf("starting srv %s: %s", addr, err) - } + return started } // Shutdown implements the [agh.Service] interface for *Service. svc may be @@ -308,20 +202,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return nil } + svc.logger.InfoContext(ctx, "shutting down") + defer svc.logger.InfoContext(ctx, "shut down") + defer func() { err = errors.Annotate(err, "shutting down: %w") }() var errs []error for _, srv := range svc.servers { - shutdownErr := srv.Shutdown(ctx) + shutdownErr := srv.shutdown(ctx) if shutdownErr != nil { - errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr)) + // Don't wrap the error, because it's informative enough as is. + errs = append(errs, err) } } if svc.pprof != nil { - shutdownErr := svc.pprof.Shutdown(ctx) + shutdownErr := svc.pprof.shutdown(ctx) if shutdownErr != nil { - errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr)) + errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr)) } } diff --git a/internal/next/websvc/websvc_internal_test.go b/internal/next/websvc/websvc_internal_test.go deleted file mode 100644 index 3509b193..00000000 --- a/internal/next/websvc/websvc_internal_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package websvc - -import "time" - -// testTimeout is the common timeout for tests. -const testTimeout = 1 * time.Second diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index f9196aac..79e46ac6 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -15,6 +15,8 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil/fakefs" @@ -22,10 +24,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} - // testTimeout is the common timeout for tests. const testTimeout = 1 * time.Second @@ -81,8 +79,6 @@ func newConfigManager() (m *configManager) { // newTestServer creates and starts a new web service instance as well as its // sole address. It also registers a cleanup procedure, which shuts the // instance down. -// -// TODO(a.garipov): Use svc or remove it. func newTestServer( t testing.TB, confMgr websvc.ConfigManager, @@ -90,6 +86,7 @@ func newTestServer( t.Helper() c := &websvc.Config{ + Logger: slogutil.NewDiscardLogger(), Pprof: &websvc.PprofConfig{ Enabled: false, }, @@ -108,7 +105,7 @@ func newTestServer( svc, err := websvc.New(c) require.NoError(t, err) - err = svc.Start() + err = svc.Start(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, func() (err error) { return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) @@ -184,10 +181,10 @@ func TestService_Start_getHealthCheck(t *testing.T) { u := &url.URL{ Scheme: urlutil.SchemeHTTP, Host: addr.String(), - Path: websvc.PathHealthCheck, + Path: websvc.PathPatternHealthCheck, } body := httpGet(t, u, http.StatusOK) - assert.Equal(t, []byte("OK"), body) + assert.Equal(t, []byte(httputil.HealthCheckHandler), body) }