diff --git a/Makefile b/Makefile index b4823bb7..cca89017 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ YARN_INSTALL_FLAGS = $(YARN_FLAGS) --network-timeout 120000 --silent\ --ignore-engines --ignore-optional --ignore-platform\ --ignore-scripts -V1API = 0 +NEXTAPI = 0 # Macros for the build-release target. If FRONTEND_PREBUILT is 0, the # default, the macro $(BUILD_RELEASE_DEPS_$(FRONTEND_PREBUILT)) expands @@ -63,7 +63,7 @@ ENV = env\ PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\ RACE='$(RACE)'\ SIGN='$(SIGN)'\ - V1API='$(V1API)'\ + NEXTAPI='$(NEXTAPI)'\ VERBOSE='$(VERBOSE)'\ VERSION='$(VERSION)'\ diff --git a/internal/aghchan/aghchan.go b/internal/aghchan/aghchan.go new file mode 100644 index 00000000..1da1790a --- /dev/null +++ b/internal/aghchan/aghchan.go @@ -0,0 +1,33 @@ +// Package aghchan contains channel utilities. +package aghchan + +import ( + "fmt" + "time" +) + +// Receive returns an error if it cannot receive a value form c before timeout +// runs out. +func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) { + var zero T + timeoutCh := time.After(timeout) + select { + case <-timeoutCh: + // TODO(a.garipov): Consider implementing [errors.Aser] for + // os.ErrTimeout. + return zero, false, fmt.Errorf("did not receive after %s", timeout) + case v, ok = <-c: + return v, ok, nil + } +} + +// MustReceive panics if it cannot receive a value form c before timeout runs +// out. +func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) { + v, ok, err := Receive(c, timeout) + if err != nil { + panic(err) + } + + return v, ok +} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 1f75a3c9..d2637d85 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -10,9 +10,9 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" @@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) { checkRefresh := func(t *testing.T, want *HostsRecord) { t.Helper() - var ok bool - var upd *netutil.IPMap - select { - case upd, ok = <-hc.Upd(): - require.True(t, ok) - require.NotNil(t, upd) - case <-time.After(1 * time.Second): - t.Fatal("did not receive after 1s") - } + upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) + require.True(t, ok) + require.NotNil(t, upd) assert.Equal(t, 1, upd.Len()) diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 2de9d372..7aae35ee 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -1,6 +1,7 @@ package aghtest import ( + "context" "io/fs" "net" @@ -15,6 +16,8 @@ import ( // Standard Library +// Package fs + // type check var _ fs.FS = &FS{} @@ -58,6 +61,8 @@ func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { return fsys.OnStat(name) } +// Package net + // type check var _ net.Listener = (*Listener)(nil) @@ -83,32 +88,10 @@ func (l *Listener) Close() (err error) { return l.OnClose() } -// Module dnsproxy - -// type check -var _ upstream.Upstream = (*UpstreamMock)(nil) - -// UpstreamMock is a mock [upstream.Upstream] implementation for tests. -// -// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and -// rename it to just Upstream. -type UpstreamMock struct { - OnAddress func() (addr string) - OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) -} - -// Address implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Address() (addr string) { - return u.OnAddress() -} - -// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { - return u.OnExchange(req) -} - // Module AdGuardHome +// Package aghos + // type check var _ aghos.FSWatcher = (*FSWatcher)(nil) @@ -133,3 +116,57 @@ func (w *FSWatcher) Add(name string) (err error) { func (w *FSWatcher) Close() (err error) { return w.OnClose() } + +// Package websvc + +// ServiceWithConfig is a mock [websvc.ServiceWithConfig] implementation for +// tests. +type ServiceWithConfig[ConfigType any] struct { + OnStart func() (err error) + OnShutdown func(ctx context.Context) (err error) + OnConfig func() (c ConfigType) +} + +// Start implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Start() (err error) { + return s.OnStart() +} + +// Shutdown implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) { + return s.OnShutdown(ctx) +} + +// Config implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { + return s.OnConfig() +} + +// Module dnsproxy + +// Package upstream + +// type check +var _ upstream.Upstream = (*UpstreamMock)(nil) + +// UpstreamMock is a mock [upstream.Upstream] implementation for tests. +// +// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and +// rename it to just Upstream. +type UpstreamMock struct { + OnAddress func() (addr string) + OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) +} + +// Address implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Address() (addr string) { + return u.OnAddress() +} + +// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + return u.OnExchange(req) +} diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index 5a465c2c..bd2c0823 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -1,9 +1,9 @@ package aghtest_test import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" ) // type check -var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil) +var _ websvc.ServiceWithConfig[struct{}] = (*aghtest.ServiceWithConfig[struct{}])(nil) diff --git a/internal/v1/agh/agh.go b/internal/next/agh/agh.go similarity index 100% rename from internal/v1/agh/agh.go rename to internal/next/agh/agh.go diff --git a/internal/v1/cmd/cmd.go b/internal/next/cmd/cmd.go similarity index 80% rename from internal/v1/cmd/cmd.go rename to internal/next/cmd/cmd.go index 2f61509b..5b329abf 100644 --- a/internal/v1/cmd/cmd.go +++ b/internal/next/cmd/cmd.go @@ -11,29 +11,32 @@ import ( "net/netip" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/log" ) // Main is the entry point of application. func Main(clientBuildFS fs.FS) { - // # Initial Configuration + // Initial Configuration start := time.Now() rand.Seed(start.UnixNano()) // TODO(a.garipov): Set up logging. - // # Web Service + // Web Service // TODO(a.garipov): Use in the Web service. _ = clientBuildFS // TODO(a.garipov): Make configurable. web := websvc.New(&websvc.Config{ - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")}, - Start: start, - Timeout: 60 * time.Second, + // TODO(a.garipov): Use an actual implementation. + ConfigManager: nil, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")}, + Start: start, + Timeout: 60 * time.Second, + ForceHTTPS: false, }) err := web.Start() diff --git a/internal/v1/cmd/signal.go b/internal/next/cmd/signal.go similarity index 96% rename from internal/v1/cmd/signal.go rename to internal/next/cmd/signal.go index b66075f6..122f3f2c 100644 --- a/internal/v1/cmd/signal.go +++ b/internal/next/cmd/signal.go @@ -4,7 +4,7 @@ import ( "os" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/golibs/log" ) diff --git a/internal/v1/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go similarity index 77% rename from internal/v1/dnssvc/dnssvc.go rename to internal/next/dnssvc/dnssvc.go index ffe5b080..f25fa294 100644 --- a/internal/v1/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -9,9 +9,10 @@ import ( "fmt" "net" "net/netip" + "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" // TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes // and replacement of module dnsproxy. "github.com/AdguardTeam/dnsproxy/proxy" @@ -47,6 +48,14 @@ type Config struct { // Service is the AdGuard Home DNS service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { + // running is an atomic boolean value. Keep it the first value in the + // struct to ensure atomic alignment. 0 means that the service is not + // running, 1 means that it is running. + // + // TODO(a.garipov): Use [atomic.Bool] in Go 1.19 or get rid of it + // completely. + running uint64 + proxy *proxy.Proxy bootstraps []string upstreams []string @@ -160,6 +169,17 @@ func (svc *Service) Start() (err error) { return nil } + defer func() { + // TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to + // tell when all servers are actually up, so at best this is merely an + // assumption. + if err != nil { + atomic.StoreUint64(&svc.running, 0) + } else { + atomic.StoreUint64(&svc.running, 1) + } + }() + return svc.proxy.Start() } @@ -173,13 +193,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return svc.proxy.Stop() } -// Config returns the current configuration of the web service. +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. func (svc *Service) Config() (c *Config) { // TODO(a.garipov): Do we need to get the TCP addresses separately? - udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) - addrs := make([]netip.AddrPort, len(udpAddrs)) - for i, a := range udpAddrs { - addrs[i] = a.(*net.UDPAddr).AddrPort() + + var addrs []netip.AddrPort + if atomic.LoadUint64(&svc.running) == 1 { + udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.(*net.UDPAddr).AddrPort() + } + } else { + conf := svc.proxy.Config + udpAddrs := conf.UDPListenAddr + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.AddrPort() + } } c = &Config{ diff --git a/internal/v1/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go similarity index 97% rename from internal/v1/dnssvc/dnssvc_test.go rename to internal/next/dnssvc/dnssvc_test.go index 5bc3b562..8205897c 100644 --- a/internal/v1/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" diff --git a/internal/next/websvc/dns.go b/internal/next/websvc/dns.go new file mode 100644 index 00000000..8846813d --- /dev/null +++ b/internal/next/websvc/dns.go @@ -0,0 +1,84 @@ +package websvc + +import ( + "encoding/json" + "fmt" + "net/http" + "net/netip" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" +) + +// DNS Settings Handlers + +// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns +// HTTP API. +type ReqPatchSettingsDNS struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` +} + +// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the +// DnsSettings object in the OpenAPI specification. +type HTTPAPIDNSSettings struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` +} + +// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP +// API. +func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) { + req := &ReqPatchSettingsDNS{ + Addresses: []netip.AddrPort{}, + BootstrapServers: []string{}, + UpstreamServers: []string{}, + } + + // TODO(a.garipov): Validate nulls and proper JSON patch. + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) + + return + } + + newConf := &dnssvc.Config{ + Addresses: req.Addresses, + BootstrapServers: req.BootstrapServers, + UpstreamServers: req.UpstreamServers, + UpstreamTimeout: time.Duration(req.UpstreamTimeout), + } + + ctx := r.Context() + err = svc.confMgr.UpdateDNS(ctx, newConf) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", err)) + + return + } + + newSvc := svc.confMgr.DNS() + err = newSvc.Start() + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("starting new service: %w", err)) + + return + } + + writeJSONOKResponse(w, r, &HTTPAPIDNSSettings{ + Addresses: newConf.Addresses, + BootstrapServers: newConf.BootstrapServers, + UpstreamServers: newConf.UpstreamServers, + UpstreamTimeout: JSONDuration(newConf.UpstreamTimeout), + }) +} diff --git a/internal/next/websvc/dns_test.go b/internal/next/websvc/dns_test.go new file mode 100644 index 00000000..f774c3d8 --- /dev/null +++ b/internal/next/websvc/dns_test.go @@ -0,0 +1,68 @@ +package websvc_test + +import ( + "context" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsDNS(t *testing.T) { + wantDNS := &websvc.HTTPAPIDNSSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")}, + BootstrapServers: []string{"1.0.0.1"}, + UpstreamServers: []string{"1.1.1.1"}, + UpstreamTimeout: websvc.JSONDuration(2 * time.Second), + } + + // TODO(a.garipov): Use [atomic.Bool] in Go 1.19. + var numStarted uint64 + confMgr := newConfigManager() + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + return &aghtest.ServiceWithConfig[*dnssvc.Config]{ + OnStart: func() (err error) { + atomic.AddUint64(&numStarted, 1) + + return nil + }, + OnShutdown: func(_ context.Context) (err error) { panic("not implemented") }, + OnConfig: func() (c *dnssvc.Config) { panic("not implemented") }, + } + } + confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsDNS, + } + + req := jobj{ + "addresses": wantDNS.Addresses, + "bootstrap_servers": wantDNS.BootstrapServers, + "upstream_servers": wantDNS.UpstreamServers, + "upstream_timeout": wantDNS.UpstreamTimeout, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIDNSSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, uint64(1), numStarted) + assert.Equal(t, wantDNS, resp) + assert.Equal(t, wantDNS, resp) +} diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go new file mode 100644 index 00000000..b58eecb9 --- /dev/null +++ b/internal/next/websvc/http.go @@ -0,0 +1,109 @@ +package websvc + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/netip" + "time" + + "github.com/AdguardTeam/golibs/log" +) + +// HTTP Settings Handlers + +// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http +// HTTP API. +type ReqPatchSettingsHTTP struct { + // TODO(a.garipov): Add more as we go. + // + // TODO(a.garipov): Add wait time. + + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` +} + +// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the +// HttpSettings object in the OpenAPI specification. +type HTTPAPIHTTPSettings struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` + ForceHTTPS bool `json:"force_https"` +} + +// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http +// HTTP API. +func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) { + req := &ReqPatchSettingsHTTP{} + + // TODO(a.garipov): Validate nulls and proper JSON patch. + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) + + return + } + + newConf := &Config{ + ConfigManager: svc.confMgr, + TLS: svc.tls, + Addresses: req.Addresses, + SecureAddresses: req.SecureAddresses, + Timeout: time.Duration(req.Timeout), + ForceHTTPS: svc.forceHTTPS, + } + + writeJSONOKResponse(w, r, &HTTPAPIHTTPSettings{ + Addresses: newConf.Addresses, + SecureAddresses: newConf.SecureAddresses, + Timeout: JSONDuration(newConf.Timeout), + ForceHTTPS: newConf.ForceHTTPS, + }) + + cancelUpd := func() {} + updCtx := context.Background() + + ctx := r.Context() + if deadline, ok := ctx.Deadline(); ok { + updCtx, cancelUpd = context.WithDeadline(updCtx, deadline) + } + + // Launch the new HTTP service in a separate goroutine to let this handler + // finish and thus, this server to shutdown. + go func() { + defer cancelUpd() + + updErr := svc.confMgr.UpdateWeb(updCtx, newConf) + if updErr != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", updErr)) + + return + } + + // TODO(a.garipov): Consider better ways to do this. + const maxUpdDur = 10 * time.Second + updStart := time.Now() + var newSvc ServiceWithConfig[*Config] + for newSvc = svc.confMgr.Web(); newSvc == svc; { + if time.Since(updStart) >= maxUpdDur { + log.Error("websvc: failed to update svc after %s", maxUpdDur) + + return + } + + log.Debug("websvc: waiting for new websvc to be configured") + time.Sleep(1 * time.Second) + } + + updErr = newSvc.Start() + if updErr != nil { + log.Error("websvc: new svc failed to start with error: %s", updErr) + } + }() +} diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go new file mode 100644 index 00000000..baf384da --- /dev/null +++ b/internal/next/websvc/http_test.go @@ -0,0 +1,62 @@ +package websvc_test + +import ( + "context" + "crypto/tls" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsHTTP(t *testing.T) { + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")}, + Timeout: websvc.JSONDuration(10 * time.Second), + ForceHTTPS: false, + } + + confMgr := newConfigManager() + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + return websvc.New(&websvc.Config{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{{}}, + }, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Timeout: 5 * time.Second, + ForceHTTPS: true, + }) + } + confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsHTTP, + } + + req := jobj{ + "addresses": wantWeb.Addresses, + "secure_addresses": wantWeb.SecureAddresses, + "timeout": wantWeb.Timeout, + "force_https": wantWeb.ForceHTTPS, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIHTTPSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, wantWeb, resp) +} diff --git a/internal/next/websvc/json.go b/internal/next/websvc/json.go new file mode 100644 index 00000000..fa2010a8 --- /dev/null +++ b/internal/next/websvc/json.go @@ -0,0 +1,143 @@ +package websvc + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/golibs/log" +) + +// JSON Utilities + +// nsecPerMsec is the number of nanoseconds in a millisecond. +const nsecPerMsec = float64(time.Millisecond / time.Nanosecond) + +// JSONDuration is a time.Duration that can be decoded from JSON and encoded +// into JSON according to our API conventions. +type JSONDuration time.Duration + +// type check +var _ json.Marshaler = JSONDuration(0) + +// MarshalJSON implements the json.Marshaler interface for JSONDuration. err is +// always nil. +func (d JSONDuration) MarshalJSON() (b []byte, err error) { + msec := float64(time.Duration(d)) / nsecPerMsec + b = strconv.AppendFloat(nil, msec, 'f', -1, 64) + + return b, nil +} + +// type check +var _ json.Unmarshaler = (*JSONDuration)(nil) + +// UnmarshalJSON implements the json.Marshaler interface for *JSONDuration. +func (d *JSONDuration) UnmarshalJSON(b []byte) (err error) { + if d == nil { + return fmt.Errorf("json duration is nil") + } + + msec, err := strconv.ParseFloat(string(b), 64) + if err != nil { + return fmt.Errorf("parsing json time: %w", err) + } + + *d = JSONDuration(int64(msec * nsecPerMsec)) + + return nil +} + +// JSONTime is a time.Time that can be decoded from JSON and encoded into JSON +// according to our API conventions. +type JSONTime time.Time + +// type check +var _ json.Marshaler = JSONTime{} + +// MarshalJSON implements the json.Marshaler interface for JSONTime. err is +// always nil. +func (t JSONTime) MarshalJSON() (b []byte, err error) { + msec := float64(time.Time(t).UnixNano()) / nsecPerMsec + b = strconv.AppendFloat(nil, msec, 'f', -1, 64) + + return b, nil +} + +// type check +var _ json.Unmarshaler = (*JSONTime)(nil) + +// UnmarshalJSON implements the json.Marshaler interface for *JSONTime. +func (t *JSONTime) UnmarshalJSON(b []byte) (err error) { + if t == nil { + return fmt.Errorf("json time is nil") + } + + msec, err := strconv.ParseFloat(string(b), 64) + if err != nil { + return fmt.Errorf("parsing json time: %w", err) + } + + *t = JSONTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) + + return nil +} + +// writeJSONOKResponse writes headers with the code 200 OK, encodes v into w, +// and logs any errors it encounters. r is used to get additional information +// from the request. +func writeJSONOKResponse(w http.ResponseWriter, r *http.Request, v any) { + writeJSONResponse(w, r, v, http.StatusOK) +} + +// writeJSONResponse writes headers with code, encodes v into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONResponse(w http.ResponseWriter, r *http.Request, v any, code int) { + // TODO(a.garipov): Put some of these to a middleware. + h := w.Header() + h.Set(aghhttp.HdrNameContentType, aghhttp.HdrValApplicationJSON) + h.Set(aghhttp.HdrNameServer, aghhttp.UserAgent()) + + w.WriteHeader(code) + + err := json.NewEncoder(w).Encode(v) + if err != nil { + log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err) + } +} + +// ErrorCode is the error code as used by the HTTP API. See the ErrorCode +// definition in the OpenAPI specification. +type ErrorCode string + +// ErrorCode constants. +// +// TODO(a.garipov): Expand and document codes. +const ( + // ErrorCodeTMP000 is the temporary error code used for all errors. + ErrorCodeTMP000 = "" +) + +// HTTPAPIErrorResp is the error response as used by the HTTP API. See the +// BadRequestResp, InternalServerErrorResp, and similar objects in the OpenAPI +// specification. +type HTTPAPIErrorResp struct { + Code ErrorCode `json:"code"` + Msg string `json:"msg"` +} + +// writeJSONErrorResponse encodes err as a JSON error into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONErrorResponse(w http.ResponseWriter, r *http.Request, err error) { + log.Error("websvc: %s %s: %s", r.Method, r.URL.Path, err) + + writeJSONResponse(w, r, &HTTPAPIErrorResp{ + Code: ErrorCodeTMP000, + Msg: err.Error(), + }, http.StatusUnprocessableEntity) +} diff --git a/internal/next/websvc/json_test.go b/internal/next/websvc/json_test.go new file mode 100644 index 00000000..90874958 --- /dev/null +++ b/internal/next/websvc/json_test.go @@ -0,0 +1,114 @@ +package websvc_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testJSONTime is the JSON time for tests. +var testJSONTime = websvc.JSONTime(time.Unix(1_234_567_890, 123_456_000).UTC()) + +// testJSONTimeStr is the string with the JSON encoding of testJSONTime. +const testJSONTimeStr = "1234567890123.456" + +func TestJSONTime_MarshalJSON(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + in websvc.JSONTime + want []byte + }{{ + name: "unix_zero", + wantErrMsg: "", + in: websvc.JSONTime(time.Unix(0, 0)), + want: []byte("0"), + }, { + name: "empty", + wantErrMsg: "", + in: websvc.JSONTime{}, + want: []byte("-6795364578871.345"), + }, { + name: "time", + wantErrMsg: "", + in: testJSONTime, + want: []byte(testJSONTimeStr), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.in.MarshalJSON() + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.want, got) + }) + } + + t.Run("json", func(t *testing.T) { + in := &struct { + A websvc.JSONTime + }{ + A: testJSONTime, + } + + got, err := json.Marshal(in) + require.NoError(t, err) + + assert.Equal(t, []byte(`{"A":`+testJSONTimeStr+`}`), got) + }) +} + +func TestJSONTime_UnmarshalJSON(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + want websvc.JSONTime + data []byte + }{{ + name: "time", + wantErrMsg: "", + want: testJSONTime, + data: []byte(testJSONTimeStr), + }, { + name: "bad", + wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` + + `invalid syntax`, + want: websvc.JSONTime{}, + data: []byte(`{}`), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var got websvc.JSONTime + err := got.UnmarshalJSON(tc.data) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.want, got) + }) + } + + t.Run("nil", func(t *testing.T) { + err := (*websvc.JSONTime)(nil).UnmarshalJSON([]byte("0")) + require.Error(t, err) + + msg := err.Error() + assert.Equal(t, "json time is nil", msg) + }) + + t.Run("json", func(t *testing.T) { + want := testJSONTime + var got struct { + A websvc.JSONTime + } + + err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got) + require.NoError(t, err) + + assert.Equal(t, want, got.A) + }) +} diff --git a/internal/v1/websvc/middleware.go b/internal/next/websvc/middleware.go similarity index 100% rename from internal/v1/websvc/middleware.go rename to internal/next/websvc/middleware.go diff --git a/internal/next/websvc/path.go b/internal/next/websvc/path.go new file mode 100644 index 00000000..e38a1d60 --- /dev/null +++ b/internal/next/websvc/path.go @@ -0,0 +1,11 @@ +package websvc + +// Path constants +const ( + PathHealthCheck = "/health-check" + + PathV1SettingsAll = "/api/v1/settings/all" + PathV1SettingsDNS = "/api/v1/settings/dns" + PathV1SettingsHTTP = "/api/v1/settings/http" + PathV1SystemInfo = "/api/v1/system/info" +) diff --git a/internal/next/websvc/settings.go b/internal/next/websvc/settings.go new file mode 100644 index 00000000..b6c5a80a --- /dev/null +++ b/internal/next/websvc/settings.go @@ -0,0 +1,42 @@ +package websvc + +import ( + "net/http" +) + +// All Settings Handlers + +// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all +// HTTP API. +type RespGetV1SettingsAll struct { + // TODO(a.garipov): Add more as we go. + + DNS *HTTPAPIDNSSettings `json:"dns"` + HTTP *HTTPAPIHTTPSettings `json:"http"` +} + +// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP +// API. +func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) { + dnsSvc := svc.confMgr.DNS() + dnsConf := dnsSvc.Config() + + webSvc := svc.confMgr.Web() + httpConf := webSvc.Config() + + // TODO(a.garipov): Add all currently supported parameters. + writeJSONOKResponse(w, r, &RespGetV1SettingsAll{ + DNS: &HTTPAPIDNSSettings{ + Addresses: dnsConf.Addresses, + BootstrapServers: dnsConf.BootstrapServers, + UpstreamServers: dnsConf.UpstreamServers, + UpstreamTimeout: JSONDuration(dnsConf.UpstreamTimeout), + }, + HTTP: &HTTPAPIHTTPSettings{ + Addresses: httpConf.Addresses, + SecureAddresses: httpConf.SecureAddresses, + Timeout: JSONDuration(httpConf.Timeout), + ForceHTTPS: httpConf.ForceHTTPS, + }, + }) +} diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go new file mode 100644 index 00000000..dadb4b55 --- /dev/null +++ b/internal/next/websvc/settings_test.go @@ -0,0 +1,74 @@ +package websvc_test + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandleGetSettingsAll(t *testing.T) { + // TODO(a.garipov): Add all currently supported parameters. + + wantDNS := &websvc.HTTPAPIDNSSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")}, + BootstrapServers: []string{"94.140.14.140", "94.140.14.141"}, + UpstreamServers: []string{"94.140.14.14", "1.1.1.1"}, + UpstreamTimeout: websvc.JSONDuration(1 * time.Second), + } + + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Timeout: websvc.JSONDuration(5 * time.Second), + ForceHTTPS: true, + } + + confMgr := newConfigManager() + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + c, err := dnssvc.New(&dnssvc.Config{ + Addresses: wantDNS.Addresses, + UpstreamServers: wantDNS.UpstreamServers, + BootstrapServers: wantDNS.BootstrapServers, + UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout), + }) + require.NoError(t, err) + + return c + } + + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + return websvc.New(&websvc.Config{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{{}}, + }, + Addresses: wantWeb.Addresses, + SecureAddresses: wantWeb.SecureAddresses, + Timeout: time.Duration(wantWeb.Timeout), + ForceHTTPS: true, + }) + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsAll, + } + + body := httpGet(t, u, http.StatusOK) + resp := &websvc.RespGetV1SettingsAll{} + err := json.Unmarshal(body, resp) + require.NoError(t, err) + + assert.Equal(t, wantDNS, resp.DNS) + assert.Equal(t, wantWeb, resp.HTTP) +} diff --git a/internal/v1/websvc/system.go b/internal/next/websvc/system.go similarity index 87% rename from internal/v1/websvc/system.go rename to internal/next/websvc/system.go index 47d0c63c..fbf60fe4 100644 --- a/internal/v1/websvc/system.go +++ b/internal/next/websvc/system.go @@ -16,20 +16,20 @@ type RespGetV1SystemInfo struct { Channel string `json:"channel"` OS string `json:"os"` NewVersion string `json:"new_version,omitempty"` - Start jsonTime `json:"start"` + Start JSONTime `json:"start"` Version string `json:"version"` } // handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP // API. func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(w, r, &RespGetV1SystemInfo{ + writeJSONOKResponse(w, r, &RespGetV1SystemInfo{ Arch: runtime.GOARCH, Channel: version.Channel(), OS: runtime.GOOS, // TODO(a.garipov): Fill this when we have an updater. NewVersion: "", - Start: jsonTime(svc.start), + Start: JSONTime(svc.start), Version: version.Version(), }) } diff --git a/internal/v1/websvc/system_test.go b/internal/next/websvc/system_test.go similarity index 82% rename from internal/v1/websvc/system_test.go rename to internal/next/websvc/system_test.go index 49579ca5..acbdcba2 100644 --- a/internal/v1/websvc/system_test.go +++ b/internal/next/websvc/system_test.go @@ -8,16 +8,17 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestService_handleGetV1SystemInfo(t *testing.T) { - _, addr := newTestServer(t) + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) u := &url.URL{ Scheme: "http", - Host: addr, + Host: addr.String(), Path: websvc.PathV1SystemInfo, } diff --git a/internal/next/websvc/waitlistener.go b/internal/next/websvc/waitlistener.go new file mode 100644 index 00000000..8ab56269 --- /dev/null +++ b/internal/next/websvc/waitlistener.go @@ -0,0 +1,31 @@ +package websvc + +import ( + "net" + "sync" +) + +// Wait Listener + +// waitListener is a wrapper around a listener that also calls wg.Done() on the +// first call to Accept. It is useful in situations where it is important to +// catch the precise moment of the first call to Accept, for example when +// starting an HTTP server. +// +// TODO(a.garipov): Move to aghnet? +type waitListener struct { + net.Listener + + firstAcceptWG *sync.WaitGroup + firstAcceptOnce sync.Once +} + +// type check +var _ net.Listener = (*waitListener)(nil) + +// Accept implements the [net.Listener] interface for *waitListener. +func (l *waitListener) Accept() (conn net.Conn, err error) { + l.firstAcceptOnce.Do(l.firstAcceptWG.Done) + + return l.Listener.Accept() +} diff --git a/internal/next/websvc/waitlistener_internal_test.go b/internal/next/websvc/waitlistener_internal_test.go new file mode 100644 index 00000000..e151341b --- /dev/null +++ b/internal/next/websvc/waitlistener_internal_test.go @@ -0,0 +1,46 @@ +package websvc + +import ( + "net" + "sync" + "sync/atomic" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/stretchr/testify/assert" +) + +func TestWaitListener_Accept(t *testing.T) { + // TODO(a.garipov): use atomic.Bool in Go 1.19. + var numAcceptCalls uint32 + var l net.Listener = &aghtest.Listener{ + OnAccept: func() (conn net.Conn, err error) { + atomic.AddUint32(&numAcceptCalls, 1) + + return nil, nil + }, + OnAddr: func() (addr net.Addr) { panic("not implemented") }, + OnClose: func() (err error) { panic("not implemented") }, + } + + wg := &sync.WaitGroup{} + wg.Add(1) + + done := make(chan struct{}) + go aghchan.MustReceive(done, testTimeout) + + go func() { + var wrapper net.Listener = &waitListener{ + Listener: l, + firstAcceptWG: wg, + } + + _, _ = wrapper.Accept() + }() + + wg.Wait() + close(done) + + assert.Equal(t, uint32(1), atomic.LoadUint32(&numAcceptCalls)) +} diff --git a/internal/v1/websvc/websvc.go b/internal/next/websvc/websvc.go similarity index 52% rename from internal/v1/websvc/websvc.go rename to internal/next/websvc/websvc.go index bbaac005..75f7d001 100644 --- a/internal/v1/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -1,4 +1,7 @@ -// Package websvc contains the AdGuard Home web service. +// Package websvc contains the AdGuard Home HTTP API service. +// +// NOTE: Packages other than cmd must not import this package, as it imports +// most other packages. // // TODO(a.garipov): Add tests. package websvc @@ -14,18 +17,46 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" httptreemux "github.com/dimfeld/httptreemux/v5" ) +// ServiceWithConfig is an extension of the [agh.Service] interface for services +// that can return their configuration. +// +// TODO(a.garipov): Consider removing this generic interface if we figure out +// how to make it testable in a better way. +type ServiceWithConfig[ConfigType any] interface { + agh.Service + + Config() (c ConfigType) +} + +// ConfigManager is the configuration manager interface. +type ConfigManager interface { + DNS() (svc ServiceWithConfig[*dnssvc.Config]) + Web() (svc ServiceWithConfig[*Config]) + + UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) + UpdateWeb(ctx context.Context, c *Config) (err error) +} + // Config is the AdGuard Home web service configuration structure. type Config struct { + // ConfigManager is used to show information about services as well as + // dynamically reconfigure them. + ConfigManager ConfigManager + // TLS is the optional TLS configuration. If TLS is not nil, // SecureAddresses must not be empty. TLS *tls.Config + // Start is the time of start of AdGuard Home. + Start time.Time + // Addresses are the addresses on which to serve the plain HTTP API. Addresses []netip.AddrPort @@ -33,40 +64,48 @@ type Config struct { // SecureAddresses is not empty, TLS must not be nil. SecureAddresses []netip.AddrPort - // Start is the time of start of AdGuard Home. - Start time.Time - // Timeout is the timeout for all server operations. Timeout time.Duration + + // ForceHTTPS tells if all requests to Addresses should be redirected to a + // secure address instead. + // + // TODO(a.garipov): Use; define rules, which address to redirect to. + ForceHTTPS bool } // Service is the AdGuard Home web service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { - tls *tls.Config - servers []*http.Server - start time.Time - timeout time.Duration + confMgr ConfigManager + tls *tls.Config + start time.Time + servers []*http.Server + timeout time.Duration + forceHTTPS bool } // New returns a new properly initialized *Service. If c is nil, svc is a nil -// *Service that does nothing. +// *Service that does nothing. The fields of c must not be modified after +// calling New. func New(c *Config) (svc *Service) { if c == nil { return nil } svc = &Service{ - tls: c.TLS, - start: c.Start, - timeout: c.Timeout, + confMgr: c.ConfigManager, + tls: c.TLS, + start: c.Start, + timeout: c.Timeout, + forceHTTPS: c.ForceHTTPS, } mux := newMux(svc) for _, a := range c.Addresses { addr := a.String() - errLog := log.StdLog("websvc: http: "+addr, log.ERROR) + errLog := log.StdLog("websvc: plain http: "+addr, log.ERROR) svc.servers = append(svc.servers, &http.Server{ Addr: addr, Handler: mux, @@ -111,6 +150,21 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) { method: http.MethodGet, path: PathHealthCheck, isJSON: false, + }, { + handler: svc.handleGetSettingsAll, + method: http.MethodGet, + path: PathV1SettingsAll, + isJSON: true, + }, { + handler: svc.handlePatchSettingsDNS, + method: http.MethodPatch, + path: PathV1SettingsDNS, + isJSON: true, + }, { + handler: svc.handlePatchSettingsHTTP, + method: http.MethodPatch, + path: PathV1SettingsHTTP, + isJSON: true, }, { handler: svc.handleGetV1SystemInfo, method: http.MethodGet, @@ -119,29 +173,41 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) { }} for _, r := range routes { - var h http.HandlerFunc if r.isJSON { - // TODO(a.garipov): Consider using httptreemux's MiddlewareFunc. - h = jsonMw(r.handler) + mux.Handle(r.method, r.path, jsonMw(r.handler)) } else { - h = r.handler + mux.Handle(r.method, r.path, r.handler) } - - mux.Handle(r.method, r.path, h) } return mux } -// Addrs returns all addresses on which this server serves the HTTP API. Addrs -// must not be called until Start returns. -func (svc *Service) Addrs() (addrs []string) { - addrs = make([]string, 0, len(svc.servers)) +// addrs returns all addresses on which this server serves the HTTP API. addrs +// must not be called simultaneously with Start. If svc was initialized with +// ":0" addresses, addrs will not return the actual bound ports until Start is +// finished. +func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { for _, srv := range svc.servers { - addrs = append(addrs, srv.Addr) + addrPort, err := netip.ParseAddrPort(srv.Addr) + if err != nil { + // Technically shouldn't happen, since all servers must have a valid + // address. + panic(fmt.Errorf("websvc: server %q: bad address: %w", srv.Addr, err)) + } + + // srv.Serve will set TLSConfig to an almost empty value, so, instead of + // relying only on the nilness of TLSConfig, check the length of the + // certificates field as well. + if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { + addrs = append(addrs, addrPort) + } else { + secureAddrs = append(secureAddrs, addrPort) + } + } - return addrs + return addrs, secureAddrs } // handleGetHealthCheck is the handler for the GET /health-check HTTP API. @@ -149,9 +215,6 @@ func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) _, _ = io.WriteString(w, "OK") } -// unit is a convenient alias for struct{}. -type unit = struct{} - // type check var _ agh.Service = (*Service)(nil) @@ -163,11 +226,9 @@ func (svc *Service) Start() (err error) { return nil } - srvs := svc.servers - wg := &sync.WaitGroup{} - wg.Add(len(srvs)) - for _, srv := range srvs { + wg.Add(len(svc.servers)) + for _, srv := range svc.servers { go serve(srv, wg) } @@ -181,11 +242,14 @@ func serve(srv *http.Server, wg *sync.WaitGroup) { addr := srv.Addr defer log.OnPanic(addr) + var proto string var l net.Listener var err error if srv.TLSConfig == nil { + proto = "http" l, err = net.Listen("tcp", addr) } else { + proto = "https" l, err = tls.Listen("tcp", addr, srv.TLSConfig) } if err != nil { @@ -196,8 +260,12 @@ func serve(srv *http.Server, wg *sync.WaitGroup) { // would mean that a random available port was automatically chosen. srv.Addr = l.Addr().String() - log.Info("websvc: starting srv http://%s", srv.Addr) - wg.Done() + log.Info("websvc: starting srv %s://%s", proto, srv.Addr) + + l = &waitListener{ + Listener: l, + firstAcceptWG: wg, + } err = srv.Serve(l) if err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -221,8 +289,28 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { } if len(errs) > 0 { - return errors.List("shutting down") + return errors.List("shutting down", errs...) } return nil } + +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. +func (svc *Service) Config() (c *Config) { + c = &Config{ + ConfigManager: svc.confMgr, + TLS: svc.tls, + // Leave Addresses and SecureAddresses empty and get the actual + // addresses that include the :0 ones later. + Start: svc.start, + Timeout: svc.timeout, + ForceHTTPS: svc.forceHTTPS, + } + + c.Addresses, c.SecureAddresses = svc.addrs() + + return c +} diff --git a/internal/next/websvc/websvc_internal_test.go b/internal/next/websvc/websvc_internal_test.go new file mode 100644 index 00000000..3509b193 --- /dev/null +++ b/internal/next/websvc/websvc_internal_test.go @@ -0,0 +1,6 @@ +package websvc + +import "time" + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go new file mode 100644 index 00000000..dbce77d5 --- /dev/null +++ b/internal/next/websvc/websvc_test.go @@ -0,0 +1,187 @@ +package websvc_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) +} + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second + +// testStart is the server start value for tests. +var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + +// type check +var _ websvc.ConfigManager = (*configManager)(nil) + +// configManager is a [websvc.ConfigManager] for tests. +type configManager struct { + onDNS func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) + onWeb func() (svc websvc.ServiceWithConfig[*websvc.Config]) + + onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error) + onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error) +} + +// DNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) DNS() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { + return m.onDNS() +} + +// Web implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) Web() (svc websvc.ServiceWithConfig[*websvc.Config]) { + return m.onWeb() +} + +// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) { + return m.onUpdateDNS(ctx, c) +} + +// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) { + return m.onUpdateWeb(ctx, c) +} + +// newConfigManager returns a *configManager all methods of which panic. +func newConfigManager() (m *configManager) { + return &configManager{ + onDNS: func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") }, + onWeb: func() (svc websvc.ServiceWithConfig[*websvc.Config]) { panic("not implemented") }, + onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) { + panic("not implemented") + }, + onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) { + panic("not implemented") + }, + } +} + +// newTestServer creates and starts a new web service instance as well as its +// sole address. It also registers a cleanup procedure, which shuts the +// instance down. +// +// TODO(a.garipov): Use svc or remove it. +func newTestServer( + t testing.TB, + confMgr websvc.ConfigManager, +) (svc *websvc.Service, addr netip.AddrPort) { + t.Helper() + + c := &websvc.Config{ + ConfigManager: confMgr, + TLS: nil, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, + SecureAddresses: nil, + Timeout: testTimeout, + Start: testStart, + ForceHTTPS: false, + } + + svc = websvc.New(c) + + err := svc.Start() + require.NoError(t, err) + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + err = svc.Shutdown(ctx) + require.NoError(t, err) + }) + + c = svc.Config() + require.NotNil(t, c) + require.Len(t, c.Addresses, 1) + + return svc, c.Addresses[0] +} + +// jobj is a utility alias for JSON objects. +type jobj map[string]any + +// httpGet is a helper that performs an HTTP GET request and returns the body of +// the response as well as checks that the status code is correct. +// +// TODO(a.garipov): Add helpers for other methods. +func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { + t.Helper() + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + require.NoErrorf(t, err, "creating req") + + httpCli := &http.Client{ + Timeout: testTimeout, + } + resp, err := httpCli.Do(req) + require.NoErrorf(t, err, "performing req") + require.Equal(t, wantCode, resp.StatusCode) + + testutil.CleanupAndRequireSuccess(t, resp.Body.Close) + + body, err = io.ReadAll(resp.Body) + require.NoErrorf(t, err, "reading body") + + return body +} + +// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded +// reqBody as the request body and returns the body of the response as well as +// checks that the status code is correct. +// +// TODO(a.garipov): Add helpers for other methods. +func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) { + t.Helper() + + b, err := json.Marshal(reqBody) + require.NoErrorf(t, err, "marshaling reqBody") + + req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b)) + require.NoErrorf(t, err, "creating req") + + httpCli := &http.Client{ + Timeout: testTimeout, + } + resp, err := httpCli.Do(req) + require.NoErrorf(t, err, "performing req") + require.Equal(t, wantCode, resp.StatusCode) + + testutil.CleanupAndRequireSuccess(t, resp.Body.Close) + + body, err = io.ReadAll(resp.Body) + require.NoErrorf(t, err, "reading body") + + return body +} + +func TestService_Start_getHealthCheck(t *testing.T) { + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathHealthCheck, + } + + body := httpGet(t, u, http.StatusOK) + + assert.Equal(t, []byte("OK"), body) +} diff --git a/internal/v1/websvc/json.go b/internal/v1/websvc/json.go deleted file mode 100644 index ef84211b..00000000 --- a/internal/v1/websvc/json.go +++ /dev/null @@ -1,61 +0,0 @@ -package websvc - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "time" - - "github.com/AdguardTeam/golibs/log" -) - -// JSON Utilities - -// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON -// according to our API conventions. -type jsonTime time.Time - -// type check -var _ json.Marshaler = jsonTime{} - -// nsecPerMsec is the number of nanoseconds in a millisecond. -const nsecPerMsec = float64(time.Millisecond / time.Nanosecond) - -// MarshalJSON implements the json.Marshaler interface for jsonTime. err is -// always nil. -func (t jsonTime) MarshalJSON() (b []byte, err error) { - msec := float64(time.Time(t).UnixNano()) / nsecPerMsec - b = strconv.AppendFloat(nil, msec, 'f', 3, 64) - - return b, nil -} - -// type check -var _ json.Unmarshaler = (*jsonTime)(nil) - -// UnmarshalJSON implements the json.Marshaler interface for *jsonTime. -func (t *jsonTime) UnmarshalJSON(b []byte) (err error) { - if t == nil { - return fmt.Errorf("json time is nil") - } - - msec, err := strconv.ParseFloat(string(b), 64) - if err != nil { - return fmt.Errorf("parsing json time: %w", err) - } - - *t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) - - return nil -} - -// writeJSONResponse encodes v into w and logs any errors it encounters. r is -// used to get additional information from the request. -func writeJSONResponse(w io.Writer, r *http.Request, v any) { - err := json.NewEncoder(w).Encode(v) - if err != nil { - log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err) - } -} diff --git a/internal/v1/websvc/path.go b/internal/v1/websvc/path.go deleted file mode 100644 index cfd67fd9..00000000 --- a/internal/v1/websvc/path.go +++ /dev/null @@ -1,8 +0,0 @@ -package websvc - -// Path constants -const ( - PathHealthCheck = "/health-check" - - PathV1SystemInfo = "/api/v1/system/info" -) diff --git a/internal/v1/websvc/websvc_test.go b/internal/v1/websvc/websvc_test.go deleted file mode 100644 index de4a9f5d..00000000 --- a/internal/v1/websvc/websvc_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package websvc_test - -import ( - "context" - "io" - "net/http" - "net/netip" - "net/url" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const testTimeout = 1 * time.Second - -// testStart is the server start value for tests. -var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) - -// newTestServer creates and starts a new web service instance as well as its -// sole address. It also registers a cleanup procedure, which shuts the -// instance down. -// -// TODO(a.garipov): Use svc or remove it. -func newTestServer(t testing.TB) (svc *websvc.Service, addr string) { - t.Helper() - - c := &websvc.Config{ - TLS: nil, - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, - SecureAddresses: nil, - Timeout: testTimeout, - Start: testStart, - } - - svc = websvc.New(c) - - err := svc.Start() - require.NoError(t, err) - t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - t.Cleanup(cancel) - - err = svc.Shutdown(ctx) - require.NoError(t, err) - }) - - addrs := svc.Addrs() - require.Len(t, addrs, 1) - - return svc, addrs[0] -} - -// httpGet is a helper that performs an HTTP GET request and returns the body of -// the response as well as checks that the status code is correct. -// -// TODO(a.garipov): Add helpers for other methods. -func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { - t.Helper() - - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - require.NoErrorf(t, err, "creating req") - - httpCli := &http.Client{ - Timeout: testTimeout, - } - resp, err := httpCli.Do(req) - require.NoErrorf(t, err, "performing req") - require.Equal(t, wantCode, resp.StatusCode) - - testutil.CleanupAndRequireSuccess(t, resp.Body.Close) - - body, err = io.ReadAll(resp.Body) - require.NoErrorf(t, err, "reading body") - - return body -} - -func TestService_Start_getHealthCheck(t *testing.T) { - _, addr := newTestServer(t) - u := &url.URL{ - Scheme: "http", - Host: addr, - Path: websvc.PathHealthCheck, - } - - body := httpGet(t, u, http.StatusOK) - - assert.Equal(t, []byte("OK"), body) -} diff --git a/internal/version/version.go b/internal/version/version.go index 2091d859..ca78efff 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -63,14 +63,6 @@ func Version() (v string) { return version } -// Constants defining the format of module information string. -const ( - modInfoAtSep = "@" - modInfoDevSep = " " - modInfoSumLeft = " (sum: " - modInfoSumRight = ")" -) - // fmtModule returns formatted information about module. The result looks like: // // github.com/Username/module@v1.2.3 (sum: someHASHSUM=) @@ -87,14 +79,16 @@ func fmtModule(m *debug.Module) (formatted string) { stringutil.WriteToBuilder(b, m.Path) if ver := m.Version; ver != "" { - sep := modInfoAtSep + sep := "@" if ver == "(devel)" { - sep = modInfoDevSep + sep = " " } + stringutil.WriteToBuilder(b, sep, ver) } + if sum := m.Sum; sum != "" { - stringutil.WriteToBuilder(b, modInfoSumLeft, sum, modInfoSumRight) + stringutil.WriteToBuilder(b, "(sum: ", sum, ")") } return b.String() diff --git a/main.go b/main.go index 03ad2f03..615a8a86 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,5 @@ -//go:build !v1 -// +build !v1 +//go:build !next +// +build !next package main diff --git a/main_v1.go b/main_next.go similarity index 79% rename from main_v1.go rename to main_next.go index 6b5f3dea..0006e87b 100644 --- a/main_v1.go +++ b/main_next.go @@ -1,12 +1,12 @@ -//go:build v1 -// +build v1 +//go:build next +// +build next package main import ( "embed" - "github.com/AdguardTeam/AdGuardHome/internal/v1/cmd" + "github.com/AdguardTeam/AdGuardHome/internal/next/cmd" ) // Embed the prebuilt client here since we strive to keep .go files inside the diff --git a/openapi/v1.yaml b/openapi/v1.yaml index 77eb1a09..adab6d4d 100644 --- a/openapi/v1.yaml +++ b/openapi/v1.yaml @@ -2289,7 +2289,7 @@ 'upstream_servers': - '1.1.1.1' - '8.8.8.8' - 'upstream_timeout': '1s' + 'upstream_timeout': 1000 'required': - 'addresses' - 'blocking_mode' @@ -2397,8 +2397,9 @@ 'type': 'array' 'upstream_timeout': 'description': > - Upstream request timeout, as a human readable duration. - 'type': 'string' + Upstream request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'DnsType': @@ -3505,14 +3506,16 @@ 'addresses': - '127.0.0.1:80' - '192.168.1.1:80' + 'force_https': true 'secure_addresses': - '127.0.0.1:443' - '192.168.1.1:443' - 'force_https': true + 'timeout': 10000 'required': - 'addresses' - - 'secure_addresses' - 'force_https' + - 'secure_addresses' + - 'timeout' 'HttpSettingsPatch': 'description': > @@ -3539,6 +3542,11 @@ 'items': 'type': 'string' 'type': 'array' + 'timeout': + 'description': > + HTTP request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'InternalServerErrorResp': diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 2cdcc90d..e04af725 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -136,11 +136,11 @@ underscores() { -e '_freebsd.go'\ -e '_linux.go'\ -e '_little.go'\ + -e '_next.go'\ -e '_openbsd.go'\ -e '_others.go'\ -e '_test.go'\ -e '_unix.go'\ - -e '_v1.go'\ -e '_windows.go' \ -v\ | sed -e 's/./\t\0/' @@ -229,7 +229,7 @@ gocyclo --over 13 ./internal/filtering/ # Apply stricter standards to new or somewhat refactored code. gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\ ./internal/aghtest/ ./internal/dnsforward/ ./internal/stats/\ - ./internal/tools/ ./internal/updater/ ./internal/v1/ ./internal/version/\ + ./internal/tools/ ./internal/updater/ ./internal/next/ ./internal/version/\ ./main.go ineffassign ./...