diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go index 7e32a568..297754f4 100644 --- a/internal/next/websvc/http_test.go +++ b/internal/next/websvc/http_test.go @@ -35,8 +35,8 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, }, - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, - SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, Timeout: 5 * time.Second, ForceHTTPS: true, }) diff --git a/internal/next/websvc/route.go b/internal/next/websvc/route.go index a04e6974..e2e5b06f 100644 --- a/internal/next/websvc/route.go +++ b/internal/next/websvc/route.go @@ -30,31 +30,31 @@ const ( // route registers all necessary handlers in mux. func (svc *Service) route(mux *http.ServeMux) { routes := []struct { - handler http.HandlerFunc + handler http.Handler pattern string isJSON bool }{{ - handler: svc.handleGetHealthCheck, + handler: httputil.HealthCheckHandler, pattern: routePatternHealthCheck, isJSON: false, }, { - handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP, + handler: http.FileServer(http.FS(svc.frontend)), pattern: routePatternFrontend, isJSON: false, }, { - handler: svc.handleGetSettingsAll, + handler: http.HandlerFunc(svc.handleGetSettingsAll), pattern: routePatternGetV1SettingsAll, isJSON: true, }, { - handler: svc.handlePatchSettingsDNS, + handler: http.HandlerFunc(svc.handlePatchSettingsDNS), pattern: routePatternPatchV1SettingsDNS, isJSON: true, }, { - handler: svc.handlePatchSettingsHTTP, + handler: http.HandlerFunc(svc.handlePatchSettingsHTTP), pattern: routePatternPatchV1SettingsHTTP, isJSON: true, }, { - handler: svc.handleGetV1SystemInfo, + handler: http.HandlerFunc(svc.handleGetV1SystemInfo), pattern: routePatternGetV1SystemInfo, isJSON: true, }} diff --git a/internal/next/websvc/server.go b/internal/next/websvc/server.go new file mode 100644 index 00000000..df0a5d52 --- /dev/null +++ b/internal/next/websvc/server.go @@ -0,0 +1,155 @@ +package websvc + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/netutil/urlutil" +) + +// server contains an *http.Server as well as entities and data associated with +// it. +// +// TODO(a.garipov): Join with similar structs in other projects and move to +// golibs/netutil/httputil. +// +// TODO(a.garipov): Once the above standardization is complete, consider +// merging debugsvc and websvc into a single httpsvc. +type server struct { + // mu protects http, logger, tcpListener, and url. + mu *sync.Mutex + http *http.Server + logger *slog.Logger + tcpListener *net.TCPListener + url *url.URL + + tlsConf *tls.Config + initialAddr netip.AddrPort +} + +// newServer returns a *server that is ready to serve HTTP queries. The TCP +// listener is not started. handler must not be nil. +func newServer( + baseLogger *slog.Logger, + initialAddr netip.AddrPort, + tlsConf *tls.Config, + handler http.Handler, + timeout time.Duration, +) (s *server) { + u := &url.URL{ + Scheme: urlutil.SchemeHTTP, + Host: initialAddr.String(), + } + + if tlsConf != nil { + u.Scheme = urlutil.SchemeHTTPS + } + + logger := baseLogger.With("server", u) + + return &server{ + mu: &sync.Mutex{}, + http: &http.Server{ + Handler: handler, + ReadTimeout: timeout, + ReadHeaderTimeout: timeout, + WriteTimeout: timeout, + IdleTimeout: timeout, + ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), + }, + logger: logger, + url: u, + + tlsConf: tlsConf, + initialAddr: initialAddr, + } +} + +// localAddr returns the local address of the server if the server has started +// listening; otherwise, it returns nil. +func (s *server) localAddr() (addr net.Addr) { + s.mu.Lock() + defer s.mu.Unlock() + + if l := s.tcpListener; l != nil { + return l.Addr() + } + + return nil +} + +// serve starts s. baseLogger is used as a base logger for s. If s fails to +// serve with anything other than [http.ErrServerClosed], it causes an unhandled +// panic. It is intended to be used as a goroutine. +// +// TODO(a.garipov): Improve error handling. +func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) { + l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr)) + if err != nil { + err = fmt.Errorf("listening tcp: %w", err) + + s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err) + + panic(fmt.Errorf("websvc: %s", err)) + } + + func() { + s.mu.Lock() + defer s.mu.Unlock() + + s.tcpListener = l + + // Reassign the address in case the port was zero. + s.url.Host = l.Addr().String() + s.logger = baseLogger.With("server", s.url) + s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError) + }() + + s.logger.InfoContext(ctx, "starting") + defer s.logger.InfoContext(ctx, "started") + + err = s.http.Serve(l) + if err == nil || errors.Is(err, http.ErrServerClosed) { + return + } + + err = fmt.Errorf("serving: %w", err) + s.logger.ErrorContext(ctx, "serve failed", slogutil.KeyError, err) + panic(fmt.Errorf("websvc: %s", err)) +} + +// shutdown shuts s down. +func (s *server) shutdown(ctx context.Context) (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + var errs []error + err = s.http.Shutdown(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err)) + } + + // Close the listener separately, as it might not have been closed if the + // context has been canceled. + // + // NOTE: The listener could remain uninitialized if [net.ListenTCP] failed + // in [s.serve]. + if l := s.tcpListener; l != nil { + err = l.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err)) + } + } + + return errors.Join(errs...) +} diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go index 6e2c7571..b2003556 100644 --- a/internal/next/websvc/settings_test.go +++ b/internal/next/websvc/settings_test.go @@ -1,7 +1,6 @@ package websvc_test import ( - "crypto/tls" "encoding/json" "net/http" "net/netip" @@ -30,13 +29,6 @@ func TestService_HandleGetSettingsAll(t *testing.T) { BootstrapPreferIPv6: true, } - wantWeb := &websvc.HTTPAPIHTTPSettings{ - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, - SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, - Timeout: aghhttp.JSONDuration(5 * time.Second), - ForceHTTPS: true, - } - confMgr := newConfigManager() confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { c, err := dnssvc.New(&dnssvc.Config{ @@ -52,35 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) { return c } - svc, err := websvc.New(&websvc.Config{ - Logger: slogutil.NewDiscardLogger(), - Pprof: &websvc.PprofConfig{ - Enabled: false, - }, - TLS: &tls.Config{ - Certificates: []tls.Certificate{{}}, - }, - Addresses: wantWeb.Addresses, - SecureAddresses: wantWeb.SecureAddresses, - Timeout: time.Duration(wantWeb.Timeout), - ForceHTTPS: true, - }) - require.NoError(t, err) - - confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { - return svc - } - - _, addr := newTestServer(t, confMgr) + svc, addr := newTestServer(t, confMgr) u := &url.URL{ Scheme: urlutil.SchemeHTTP, Host: addr.String(), Path: websvc.PathPatternV1SettingsAll, } + confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { + return svc + } + + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{addr}, + SecureAddresses: nil, + Timeout: aghhttp.JSONDuration(testTimeout), + ForceHTTPS: false, + } + body := httpGet(t, u, http.StatusOK) resp := &websvc.RespGetV1SettingsAll{} - err = json.Unmarshal(body, resp) + err := json.Unmarshal(body, resp) require.NoError(t, err) assert.Equal(t, wantDNS, resp.DNS) diff --git a/internal/next/websvc/waitlistener.go b/internal/next/websvc/waitlistener.go deleted file mode 100644 index 8ab56269..00000000 --- a/internal/next/websvc/waitlistener.go +++ /dev/null @@ -1,31 +0,0 @@ -package websvc - -import ( - "net" - "sync" -) - -// Wait Listener - -// waitListener is a wrapper around a listener that also calls wg.Done() on the -// first call to Accept. It is useful in situations where it is important to -// catch the precise moment of the first call to Accept, for example when -// starting an HTTP server. -// -// TODO(a.garipov): Move to aghnet? -type waitListener struct { - net.Listener - - firstAcceptWG *sync.WaitGroup - firstAcceptOnce sync.Once -} - -// type check -var _ net.Listener = (*waitListener)(nil) - -// Accept implements the [net.Listener] interface for *waitListener. -func (l *waitListener) Accept() (conn net.Conn, err error) { - l.firstAcceptOnce.Do(l.firstAcceptWG.Done) - - return l.Listener.Accept() -} diff --git a/internal/next/websvc/waitlistener_internal_test.go b/internal/next/websvc/waitlistener_internal_test.go deleted file mode 100644 index 089c0531..00000000 --- a/internal/next/websvc/waitlistener_internal_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package websvc - -import ( - "net" - "sync" - "sync/atomic" - "testing" - - "github.com/AdguardTeam/golibs/testutil/fakenet" - "github.com/stretchr/testify/assert" -) - -func TestWaitListener_Accept(t *testing.T) { - var accepted atomic.Bool - var l net.Listener = &fakenet.Listener{ - OnAccept: func() (conn net.Conn, err error) { - accepted.Store(true) - - return nil, nil - }, - OnAddr: func() (addr net.Addr) { panic("not implemented") }, - OnClose: func() (err error) { panic("not implemented") }, - } - - wg := &sync.WaitGroup{} - wg.Add(1) - - go func() { - var wrapper net.Listener = &waitListener{ - Listener: l, - firstAcceptWG: wg, - } - - _, _ = wrapper.Accept() - }() - - wg.Wait() - - assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10) -} diff --git a/internal/next/websvc/websvc.go b/internal/next/websvc/websvc.go index 58e08390..88325e98 100644 --- a/internal/next/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -10,24 +10,18 @@ import ( "context" "crypto/tls" "fmt" - "io" "io/fs" "log/slog" - "net" "net/http" "net/netip" - "net/url" "runtime" - "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/AdguardTeam/golibs/mathutil" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil/httputil" - "github.com/AdguardTeam/golibs/netutil/urlutil" ) // ConfigManager is the configuration manager interface. @@ -55,12 +49,6 @@ type Service struct { forceHTTPS bool } -// server is a wrapper around http.Server with additional information. -type server struct { - logURL *url.URL - http.Server -} - // New returns a new properly initialized *Service. If c is nil, svc is a nil // *Service that does nothing. The fields of c must not be modified after // calling New. @@ -123,44 +111,9 @@ func (svc *Service) setupPprof(c *PprofConfig) { svc.pprofPort = c.Port addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port) - // TODO(a.garipov): Consider making pprof timeout configurable. svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute) } -// newServer returns a new *server with the given parameters. -func newServer( - logger *slog.Logger, - addr netip.AddrPort, - tlsConf *tls.Config, - h http.Handler, - timeout time.Duration, -) (srv *server) { - addrStr := addr.String() - logURL := &url.URL{ - Scheme: urlutil.SchemeHTTP, - Host: addrStr, - } - - if tlsConf != nil { - logURL.Scheme = urlutil.SchemeHTTPS - } - - l := logger.With("addr", logURL) - return &server{ - logURL: logURL, - Server: http.Server{ - Addr: addrStr, - Handler: h, - TLSConfig: tlsConf, - ReadTimeout: timeout, - WriteTimeout: timeout, - IdleTimeout: timeout, - ReadHeaderTimeout: timeout, - ErrorLog: slog.NewLogLogger(l.Handler(), slog.LevelError), - }, - } -} - // addrs returns all addresses on which this server serves the HTTP API. addrs // must not be called simultaneously with Start. If svc was initialized with // ":0" addresses, addrs will not return the actual bound ports until Start is @@ -171,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { } for _, srv := range svc.servers { - // Use MustParseAddrPort, since no errors should technically happen - // here, because all servers must have a valid address. - addrPort := netip.MustParseAddrPort(srv.Addr) + addrPort := netutil.NetAddrToAddrPort(srv.localAddr()) + if addrPort == (netip.AddrPort{}) { + continue + } - // [srv.Serve] will set TLSConfig to an almost empty value, so, instead - // of relying only on the nilness of TLSConfig, check the length of the - // certificates field as well. - if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { + if srv.tlsConf == nil { addrs = append(addrs, addrPort) } else { secureAddrs = append(secureAddrs, addrPort) @@ -188,11 +139,6 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { return addrs, secureAddrs } -// handleGetHealthCheck is the handler for the GET /health-check HTTP API. -func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, "OK") -} - // type check var _ agh.ServiceWithConfig[*Config] = (*Service)(nil) @@ -206,59 +152,33 @@ func (svc *Service) Start(ctx context.Context) (err error) { return nil } - pprofEnabled := svc.pprof != nil - srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled) + svc.logger.InfoContext(ctx, "starting") + defer svc.logger.InfoContext(ctx, "started") - wg := &sync.WaitGroup{} - wg.Add(srvNum) for _, srv := range svc.servers { - go serve(ctx, svc.logger, srv, wg) + go srv.serve(ctx, svc.logger) } - if pprofEnabled { - go serve(ctx, svc.logger, svc.pprof, wg) + if svc.pprof != nil { + go svc.pprof.serve(ctx, svc.logger) } - wg.Wait() + started := false + for !started { + select { + case <-ctx.Done(): + return ctx.Err() + default: + started = true + for _, srv := range svc.servers { + started = started && srv.localAddr() != nil + } + } + } return nil } -// serve starts and runs srv and writes all errors into its log. -func serve(ctx context.Context, logger *slog.Logger, srv *server, wg *sync.WaitGroup) { - defer slogutil.RecoverAndLog(ctx, logger) - - var l net.Listener - var err error - addr := srv.Addr - if srv.TLSConfig == nil { - l, err = net.Listen("tcp", addr) - } else { - l, err = tls.Listen("tcp", addr, srv.TLSConfig) - } - if err != nil { - logger.WarnContext(ctx, "binding", "tcp_addr", addr, slogutil.KeyError, err) - } - - // Update the server's address in case the address had the port zero, which - // would mean that a random available port was automatically chosen. - srv.Addr = l.Addr().String() - srv.logURL.Host = srv.Addr - - logger = logger.With("addr", srv.logURL) - logger.InfoContext(ctx, "starting") - - l = &waitListener{ - Listener: l, - firstAcceptWG: wg, - } - - err = srv.Serve(l) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.WarnContext(ctx, "starting server", "tcp_addr", addr, slogutil.KeyError, err) - } -} - // Shutdown implements the [agh.Service] interface for *Service. svc may be // nil. func (svc *Service) Shutdown(ctx context.Context) (err error) { @@ -266,20 +186,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return nil } + svc.logger.InfoContext(ctx, "shutting down") + defer svc.logger.InfoContext(ctx, "shut down") + defer func() { err = errors.Annotate(err, "shutting down: %w") }() var errs []error for _, srv := range svc.servers { - shutdownErr := srv.Shutdown(ctx) + shutdownErr := srv.shutdown(ctx) if shutdownErr != nil { - errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr)) + // Don't wrap the error, because it's informative enough as is. + errs = append(errs, err) } } if svc.pprof != nil { - shutdownErr := svc.pprof.Shutdown(ctx) + shutdownErr := svc.pprof.shutdown(ctx) if shutdownErr != nil { - errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr)) + errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr)) } } diff --git a/internal/next/websvc/websvc_internal_test.go b/internal/next/websvc/websvc_internal_test.go deleted file mode 100644 index 3509b193..00000000 --- a/internal/next/websvc/websvc_internal_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package websvc - -import "time" - -// testTimeout is the common timeout for tests. -const testTimeout = 1 * time.Second diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index ccbf57b4..79e46ac6 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -16,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil/fakefs" @@ -78,8 +79,6 @@ func newConfigManager() (m *configManager) { // newTestServer creates and starts a new web service instance as well as its // sole address. It also registers a cleanup procedure, which shuts the // instance down. -// -// TODO(a.garipov): Use svc or remove it. func newTestServer( t testing.TB, confMgr websvc.ConfigManager, @@ -187,5 +186,5 @@ func TestService_Start_getHealthCheck(t *testing.T) { body := httpGet(t, u, http.StatusOK) - assert.Equal(t, []byte("OK"), body) + assert.Equal(t, []byte(httputil.HealthCheckHandler), body) }