mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-15 19:41:36 +03:00
126 lines
2.9 KiB
Go
126 lines
2.9 KiB
Go
|
package dnssvc_test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net/netip"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||
|
"github.com/AdguardTeam/golibs/testutil"
|
||
|
"github.com/miekg/dns"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestMain(m *testing.M) {
|
||
|
testutil.DiscardLogOutput(m)
|
||
|
}
|
||
|
|
||
|
// testTimeout is the common timeout for tests.  It is used for channel
// receives, for the upstream timeout of the service under test, and for the
// contexts and polling that guard DNS exchanges and shutdown.
const testTimeout = 1 * time.Second
|
||
|
|
||
|
func TestService(t *testing.T) {
|
||
|
const (
|
||
|
listenAddr = "127.0.0.1:0"
|
||
|
bootstrapAddr = "127.0.0.1:0"
|
||
|
upstreamAddr = "upstream.example"
|
||
|
)
|
||
|
|
||
|
upstreamErrCh := make(chan error, 1)
|
||
|
upstreamStartedCh := make(chan struct{})
|
||
|
upstreamSrv := &dns.Server{
|
||
|
Addr: bootstrapAddr,
|
||
|
Net: "udp",
|
||
|
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||
|
pt := testutil.PanicT{}
|
||
|
|
||
|
resp := (&dns.Msg{}).SetReply(req)
|
||
|
resp.Answer = append(resp.Answer, &dns.A{
|
||
|
Hdr: dns.RR_Header{},
|
||
|
A: netip.MustParseAddrPort(bootstrapAddr).Addr().AsSlice(),
|
||
|
})
|
||
|
|
||
|
writeErr := w.WriteMsg(resp)
|
||
|
require.NoError(pt, writeErr)
|
||
|
}),
|
||
|
NotifyStartedFunc: func() { close(upstreamStartedCh) },
|
||
|
}
|
||
|
|
||
|
go func() {
|
||
|
listenErr := upstreamSrv.ListenAndServe()
|
||
|
if listenErr != nil {
|
||
|
// Log these immediately to see what happens.
|
||
|
t.Logf("upstream listen error: %s", listenErr)
|
||
|
}
|
||
|
|
||
|
upstreamErrCh <- listenErr
|
||
|
}()
|
||
|
|
||
|
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
|
||
|
|
||
|
c := &dnssvc.Config{
|
||
|
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
|
||
|
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
|
||
|
UpstreamServers: []string{upstreamAddr},
|
||
|
DNS64Prefixes: nil,
|
||
|
UpstreamTimeout: testTimeout,
|
||
|
BootstrapPreferIPv6: false,
|
||
|
UseDNS64: false,
|
||
|
}
|
||
|
|
||
|
svc, err := dnssvc.New(c)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = svc.Start()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
gotConf := svc.Config()
|
||
|
require.NotNil(t, gotConf)
|
||
|
require.Len(t, gotConf.Addresses, 1)
|
||
|
|
||
|
addr := gotConf.Addresses[0]
|
||
|
|
||
|
t.Run("dns", func(t *testing.T) {
|
||
|
req := &dns.Msg{
|
||
|
MsgHdr: dns.MsgHdr{
|
||
|
Id: dns.Id(),
|
||
|
RecursionDesired: true,
|
||
|
},
|
||
|
Question: []dns.Question{{
|
||
|
Name: "example.com.",
|
||
|
Qtype: dns.TypeA,
|
||
|
Qclass: dns.ClassINET,
|
||
|
}},
|
||
|
}
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||
|
defer cancel()
|
||
|
|
||
|
cli := &dns.Client{}
|
||
|
|
||
|
var resp *dns.Msg
|
||
|
require.Eventually(t, func() (ok bool) {
|
||
|
var excErr error
|
||
|
resp, _, excErr = cli.ExchangeContext(ctx, req, addr.String())
|
||
|
|
||
|
return excErr == nil
|
||
|
}, testTimeout, testTimeout/10)
|
||
|
|
||
|
assert.NotNil(t, resp)
|
||
|
})
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||
|
defer cancel()
|
||
|
|
||
|
err = svc.Shutdown(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = upstreamSrv.Shutdown()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err, ok := testutil.RequireReceive(t, upstreamErrCh, testTimeout)
|
||
|
require.True(t, ok)
|
||
|
require.NoError(t, err)
|
||
|
}
|