2021-02-04 20:35:13 +03:00
|
|
|
// Package aghtest contains utilities for testing.
|
|
|
|
package aghtest
|
2020-11-16 15:52:05 +03:00
|
|
|
|
|
|
|
import (
|
2023-07-20 14:26:35 +03:00
|
|
|
"crypto/sha256"
|
2020-11-16 19:45:31 +03:00
|
|
|
"io"
|
2023-09-21 17:07:57 +03:00
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
2023-08-23 16:58:24 +03:00
|
|
|
"net/netip"
|
2023-09-21 17:07:57 +03:00
|
|
|
"net/url"
|
2020-11-16 15:52:05 +03:00
|
|
|
"testing"
|
2024-03-11 18:17:04 +03:00
|
|
|
"time"
|
2020-11-16 15:52:05 +03:00
|
|
|
|
2024-03-11 18:17:04 +03:00
|
|
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
2020-11-16 15:52:05 +03:00
|
|
|
"github.com/AdguardTeam/golibs/log"
|
2024-03-11 18:17:04 +03:00
|
|
|
"github.com/AdguardTeam/golibs/netutil"
|
|
|
|
"github.com/AdguardTeam/golibs/testutil"
|
|
|
|
"github.com/miekg/dns"
|
2023-09-21 17:07:57 +03:00
|
|
|
"github.com/stretchr/testify/require"
|
2020-11-16 15:52:05 +03:00
|
|
|
)
|
|
|
|
|
2023-08-09 16:27:21 +03:00
|
|
|
// Host names shared by the filtering tests.
const (
	// ReqHost is the request host that filtering tests use by default.
	ReqHost = "www.host.example"

	// ReqFQDN is ReqHost written as a fully-qualified domain name, that is,
	// with the trailing root label dot.
	ReqFQDN = ReqHost + "."
)
|
|
|
|
|
2020-11-16 19:45:31 +03:00
|
|
|
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
|
|
|
|
// revert changes.
|
2022-02-03 21:19:32 +03:00
|
|
|
func ReplaceLogWriter(t testing.TB, w io.Writer) {
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
prev := log.Writer()
|
|
|
|
t.Cleanup(func() { log.SetOutput(prev) })
|
2020-11-16 19:45:31 +03:00
|
|
|
log.SetOutput(w)
|
|
|
|
}
|
|
|
|
|
|
|
|
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
|
|
|
|
// revert changes.
|
2022-02-03 21:19:32 +03:00
|
|
|
func ReplaceLogLevel(t testing.TB, l log.Level) {
|
|
|
|
t.Helper()
|
|
|
|
|
2020-11-16 19:45:31 +03:00
|
|
|
switch l {
|
|
|
|
case log.INFO, log.DEBUG, log.ERROR:
|
|
|
|
// Go on.
|
|
|
|
default:
|
|
|
|
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
|
|
|
|
}
|
|
|
|
|
2022-02-03 21:19:32 +03:00
|
|
|
prev := log.GetLevel()
|
|
|
|
t.Cleanup(func() { log.SetLevel(prev) })
|
2020-11-16 19:45:31 +03:00
|
|
|
log.SetLevel(l)
|
|
|
|
}
|
2023-07-20 14:26:35 +03:00
|
|
|
|
|
|
|
// HostToIPs deterministically derives one IPv4 and one IPv6 address from host.
// Both come from the SHA-256 digest of host:
//   - the IPv4 address is the digest's first 4 bytes;
//   - the IPv6 address is the following 16 bytes.
func HostToIPs(host string) (ipv4, ipv6 netip.Addr) {
	digest := sha256.Sum256([]byte(host))

	var v4Bytes [4]byte
	copy(v4Bytes[:], digest[0:4])

	var v6Bytes [16]byte
	copy(v6Bytes[:], digest[4:20])

	ipv4 = netip.AddrFrom4(v4Bytes)
	ipv6 = netip.AddrFrom16(v6Bytes)

	return ipv4, ipv6
}
|
2023-09-21 17:07:57 +03:00
|
|
|
|
|
|
|
// StartHTTPServer is a helper that starts the HTTP server, which is configured
|
|
|
|
// to return data on every request, and returns the client and server URL.
|
|
|
|
func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
|
|
_, _ = w.Write(data)
|
|
|
|
}))
|
|
|
|
t.Cleanup(srv.Close)
|
|
|
|
|
|
|
|
u, err := url.Parse(srv.URL)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
return srv.Client(), u
|
|
|
|
}
|
2024-03-11 18:17:04 +03:00
|
|
|
|
|
|
|
// testTimeout bounds how long the test helpers in this package wait, e.g. for a
// test server to start, and is also used as that server's I/O timeout.
//
// TODO(e.burkov): Move into agdctest.
const testTimeout time.Duration = time.Second
|
|
|
|
|
|
|
|
// StartLocalhostUpstream is a test helper that starts a DNS server on
|
|
|
|
// localhost.
|
|
|
|
func StartLocalhostUpstream(t *testing.T, h dns.Handler) (addr *url.URL) {
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
startCh := make(chan netip.AddrPort)
|
|
|
|
defer close(startCh)
|
|
|
|
errCh := make(chan error)
|
|
|
|
|
|
|
|
srv := &dns.Server{
|
|
|
|
Addr: "127.0.0.1:0",
|
|
|
|
Net: string(proxy.ProtoTCP),
|
|
|
|
Handler: h,
|
|
|
|
ReadTimeout: testTimeout,
|
|
|
|
WriteTimeout: testTimeout,
|
|
|
|
}
|
|
|
|
srv.NotifyStartedFunc = func() {
|
|
|
|
addrPort := srv.Listener.Addr()
|
|
|
|
startCh <- netutil.NetAddrToAddrPort(addrPort)
|
|
|
|
}
|
|
|
|
|
|
|
|
go func() { errCh <- srv.ListenAndServe() }()
|
|
|
|
|
|
|
|
select {
|
|
|
|
case addrPort := <-startCh:
|
|
|
|
addr = &url.URL{
|
|
|
|
Scheme: string(proxy.ProtoTCP),
|
|
|
|
Host: addrPort.String(),
|
|
|
|
}
|
|
|
|
|
|
|
|
testutil.CleanupAndRequireSuccess(t, func() (err error) { return <-errCh })
|
|
|
|
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
|
|
|
|
case err := <-errCh:
|
|
|
|
require.NoError(t, err)
|
|
|
|
case <-time.After(testTimeout):
|
|
|
|
require.FailNow(t, "timeout exceeded")
|
|
|
|
}
|
|
|
|
|
|
|
|
return addr
|
|
|
|
}
|