next/websvc: upd more

This commit is contained in:
Ainar Garipov 2024-11-12 16:53:15 +03:00
parent 0c64e6cfc6
commit d729aa150f
9 changed files with 210 additions and 225 deletions

View file

@ -35,8 +35,8 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
Timeout: 5 * time.Second,
ForceHTTPS: true,
})

View file

@ -30,31 +30,31 @@ const (
// route registers all necessary handlers in mux.
func (svc *Service) route(mux *http.ServeMux) {
routes := []struct {
handler http.HandlerFunc
handler http.Handler
pattern string
isJSON bool
}{{
handler: svc.handleGetHealthCheck,
handler: httputil.HealthCheckHandler,
pattern: routePatternHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
handler: http.FileServer(http.FS(svc.frontend)),
pattern: routePatternFrontend,
isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
handler: http.HandlerFunc(svc.handleGetSettingsAll),
pattern: routePatternGetV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
handler: http.HandlerFunc(svc.handlePatchSettingsDNS),
pattern: routePatternPatchV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
handler: http.HandlerFunc(svc.handlePatchSettingsHTTP),
pattern: routePatternPatchV1SettingsHTTP,
isJSON: true,
}, {
handler: svc.handleGetV1SystemInfo,
handler: http.HandlerFunc(svc.handleGetV1SystemInfo),
pattern: routePatternGetV1SystemInfo,
isJSON: true,
}}

View file

@ -0,0 +1,155 @@
package websvc
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// server contains an *http.Server as well as entities and data associated with
// it.
//
// TODO(a.garipov): Join with similar structs in other projects and move to
// golibs/netutil/httputil.
//
// TODO(a.garipov): Once the above standardization is complete, consider
// merging debugsvc and websvc into a single httpsvc.
type server struct {
// mu protects http, logger, tcpListener, and url.
mu *sync.Mutex
http *http.Server
logger *slog.Logger
tcpListener *net.TCPListener
url *url.URL
tlsConf *tls.Config
initialAddr netip.AddrPort
}
// newServer returns a *server that is ready to serve HTTP queries. The TCP
// listener is not started. handler must not be nil.
func newServer(
baseLogger *slog.Logger,
initialAddr netip.AddrPort,
tlsConf *tls.Config,
handler http.Handler,
timeout time.Duration,
) (s *server) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: initialAddr.String(),
}
if tlsConf != nil {
u.Scheme = urlutil.SchemeHTTPS
}
logger := baseLogger.With("server", u)
return &server{
mu: &sync.Mutex{},
http: &http.Server{
Handler: handler,
ReadTimeout: timeout,
ReadHeaderTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
},
logger: logger,
url: u,
tlsConf: tlsConf,
initialAddr: initialAddr,
}
}
// localAddr returns the local address of the server if the server has started
// listening; otherwise, it returns nil.
func (s *server) localAddr() (addr net.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
if l := s.tcpListener; l != nil {
return l.Addr()
}
return nil
}
// serve starts s. baseLogger is used as a base logger for s. If s fails to
// serve with anything other than [http.ErrServerClosed], it causes an unhandled
// panic. It is intended to be used as a goroutine.
//
// TODO(a.garipov): Improve error handling.
func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) {
l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr))
if err != nil {
err = fmt.Errorf("listening tcp: %w", err)
s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: %s", err))
}
func() {
s.mu.Lock()
defer s.mu.Unlock()
s.tcpListener = l
// Reassign the address in case the port was zero.
s.url.Host = l.Addr().String()
s.logger = baseLogger.With("server", s.url)
s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)
}()
s.logger.InfoContext(ctx, "starting")
defer s.logger.InfoContext(ctx, "started")
err = s.http.Serve(l)
if err == nil || errors.Is(err, http.ErrServerClosed) {
return
}
err = fmt.Errorf("serving: %w", err)
s.logger.ErrorContext(ctx, "serve failed", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: %s", err))
}
// shutdown shuts s down.
func (s *server) shutdown(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
var errs []error
err = s.http.Shutdown(ctx)
if err != nil {
errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err))
}
// Close the listener separately, as it might not have been closed if the
// context has been canceled.
//
// NOTE: The listener could remain uninitialized if [net.ListenTCP] failed
// in [s.serve].
if l := s.tcpListener; l != nil {
err = l.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err))
}
}
return errors.Join(errs...)
}

View file

@ -1,7 +1,6 @@
package websvc_test
import (
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
@ -30,13 +29,6 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
BootstrapPreferIPv6: true,
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: aghhttp.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
c, err := dnssvc.New(&dnssvc.Config{
@ -52,35 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
return c
}
svc, err := websvc.New(&websvc.Config{
Logger: slogutil.NewDiscardLogger(),
Pprof: &websvc.PprofConfig{
Enabled: false,
},
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
require.NoError(t, err)
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return svc
}
_, addr := newTestServer(t, confMgr)
svc, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathPatternV1SettingsAll,
}
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return svc
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{addr},
SecureAddresses: nil,
Timeout: aghhttp.JSONDuration(testTimeout),
ForceHTTPS: false,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err = json.Unmarshal(body, resp)
err := json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS)

View file

@ -1,31 +0,0 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View file

@ -1,40 +0,0 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
var accepted atomic.Bool
var l net.Listener = &fakenet.Listener{
OnAccept: func() (conn net.Conn, err error) {
accepted.Store(true)
return nil, nil
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
}

View file

@ -10,24 +10,18 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"io/fs"
"log/slog"
"net"
"net/http"
"net/netip"
"net/url"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// ConfigManager is the configuration manager interface.
@ -55,12 +49,6 @@ type Service struct {
forceHTTPS bool
}
// server is a wrapper around http.Server with additional information.
type server struct {
logURL *url.URL
http.Server
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
@ -123,44 +111,9 @@ func (svc *Service) setupPprof(c *PprofConfig) {
svc.pprofPort = c.Port
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
// TODO(a.garipov): Consider making pprof timeout configurable.
svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute)
}
// newServer returns a new *server with the given parameters.
func newServer(
logger *slog.Logger,
addr netip.AddrPort,
tlsConf *tls.Config,
h http.Handler,
timeout time.Duration,
) (srv *server) {
addrStr := addr.String()
logURL := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addrStr,
}
if tlsConf != nil {
logURL.Scheme = urlutil.SchemeHTTPS
}
l := logger.With("addr", logURL)
return &server{
logURL: logURL,
Server: http.Server{
Addr: addrStr,
Handler: h,
TLSConfig: tlsConf,
ReadTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ReadHeaderTimeout: timeout,
ErrorLog: slog.NewLogLogger(l.Handler(), slog.LevelError),
},
}
}
// addrs returns all addresses on which this server serves the HTTP API. addrs
// must not be called simultaneously with Start. If svc was initialized with
// ":0" addresses, addrs will not return the actual bound ports until Start is
@ -171,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
}
for _, srv := range svc.servers {
// Use MustParseAddrPort, since no errors should technically happen
// here, because all servers must have a valid address.
addrPort := netip.MustParseAddrPort(srv.Addr)
addrPort := netutil.NetAddrToAddrPort(srv.localAddr())
if addrPort == (netip.AddrPort{}) {
continue
}
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
// of relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
if srv.tlsConf == nil {
addrs = append(addrs, addrPort)
} else {
secureAddrs = append(secureAddrs, addrPort)
@ -188,11 +139,6 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
return addrs, secureAddrs
}
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// type check
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
@ -206,59 +152,33 @@ func (svc *Service) Start(ctx context.Context) (err error) {
return nil
}
pprofEnabled := svc.pprof != nil
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
svc.logger.InfoContext(ctx, "starting")
defer svc.logger.InfoContext(ctx, "started")
wg := &sync.WaitGroup{}
wg.Add(srvNum)
for _, srv := range svc.servers {
go serve(ctx, svc.logger, srv, wg)
go srv.serve(ctx, svc.logger)
}
if pprofEnabled {
go serve(ctx, svc.logger, svc.pprof, wg)
if svc.pprof != nil {
go svc.pprof.serve(ctx, svc.logger)
}
wg.Wait()
started := false
for !started {
select {
case <-ctx.Done():
return ctx.Err()
default:
started = true
for _, srv := range svc.servers {
started = started && srv.localAddr() != nil
}
}
}
return nil
}
// serve starts and runs srv and writes all errors into its log.
func serve(ctx context.Context, logger *slog.Logger, srv *server, wg *sync.WaitGroup) {
defer slogutil.RecoverAndLog(ctx, logger)
var l net.Listener
var err error
addr := srv.Addr
if srv.TLSConfig == nil {
l, err = net.Listen("tcp", addr)
} else {
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
}
if err != nil {
logger.WarnContext(ctx, "binding", "tcp_addr", addr, slogutil.KeyError, err)
}
// Update the server's address in case the address had the port zero, which
// would mean that a random available port was automatically chosen.
srv.Addr = l.Addr().String()
srv.logURL.Host = srv.Addr
logger = logger.With("addr", srv.logURL)
logger.InfoContext(ctx, "starting")
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
err = srv.Serve(l)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.WarnContext(ctx, "starting server", "tcp_addr", addr, slogutil.KeyError, err)
}
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
@ -266,20 +186,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil
}
svc.logger.InfoContext(ctx, "shutting down")
defer svc.logger.InfoContext(ctx, "shut down")
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
var errs []error
for _, srv := range svc.servers {
shutdownErr := srv.Shutdown(ctx)
shutdownErr := srv.shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
// Don't wrap the error, because it's informative enough as is.
errs = append(errs, err)
}
}
if svc.pprof != nil {
shutdownErr := svc.pprof.Shutdown(ctx)
shutdownErr := svc.pprof.shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
}
}

View file

@ -1,6 +0,0 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View file

@ -16,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
@ -78,8 +79,6 @@ func newConfigManager() (m *configManager) {
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
@ -187,5 +186,5 @@ func TestService_Start_getHealthCheck(t *testing.T) {
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
assert.Equal(t, []byte(httputil.HealthCheckHandler), body)
}