Pull request #1558: add-dnssvc

Merge in DNS/adguard-home from add-dnssvc to master

Squashed commit of the following:

commit 55f4f114bab65a03c0d65383e89020a7356cff32
Merge: 95dc28d9 6e63757f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Aug 15 20:53:07 2022 +0300

    Merge branch 'master' into add-dnssvc

commit 95dc28d9d77d06e8ac98c1e6772557bffbf1705b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Aug 15 20:52:50 2022 +0300

    all: imp tests, docs

commit 0d9d02950d84afd160b4b1c118da856cee6f12e5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 11 19:27:59 2022 +0300

    all: imp docs

commit 8990e038a81da4430468da12fcebedf79fe14df6
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 11 19:05:29 2022 +0300

    all: imp tests more

commit 92730d93a2a1ac77888c2655508e43efaf0e9fde
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 11 18:37:48 2022 +0300

    all: imp tests more

commit 8cd45ba30da7ac310e9dc666fb2af438e577b02d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 11 18:11:15 2022 +0300

    all: add v1 dnssvc stub; refactor tests
This commit is contained in:
Ainar Garipov 2022-08-16 13:21:25 +03:00
parent 6e63757fc7
commit d4c3a43bcb
19 changed files with 742 additions and 319 deletions

View file

@ -10,6 +10,20 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
// Coalesce returns the first non-zero value. It is named after the function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// UniqChecker allows validating uniqueness of comparable items. // UniqChecker allows validating uniqueness of comparable items.
// //
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate. // TODO(a.garipov): The Ordered constraint is only really necessary in Validate.

View file

@ -470,7 +470,7 @@ func TestHostsContainer(t *testing.T) {
}}, }},
}, { }, {
req: &urlfilter.DNSRequest{ req: &urlfilter.DNSRequest{
Hostname: "nonexisting", Hostname: "nonexistent.example",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
name: "non-existing", name: "non-existing",

View file

@ -1,4 +1,4 @@
package aghos package aghos_test
import ( import (
"testing" "testing"

View file

@ -0,0 +1,57 @@
package aghos
import (
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// errFS is an fs.FS implementation, method Open of which always returns
// errFSOpen.
type errFS struct{}
// errFSOpen is returned from errGlobFS.Open.
const errFSOpen errors.Error = "test open error"
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
// err is always errFSOpen.
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
return nil, errFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View file

@ -1,13 +1,13 @@
package aghos package aghos_test
import ( import (
"bufio" "bufio"
"io" "io"
"io/fs"
"path" "path"
"testing" "testing"
"testing/fstest" "testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -16,7 +16,7 @@ import (
func TestFileWalker_Walk(t *testing.T) { func TestFileWalker_Walk(t *testing.T) {
const attribute = `000` const attribute = `000`
makeFileWalker := func(_ string) (fw FileWalker) { makeFileWalker := func(_ string) (fw aghos.FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) { return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
f := fstest.MapFS{ f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")}, filename: &fstest.MapFile{Data: []byte("[]")},
} }
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
patterns = append(patterns, s.Text()) patterns = append(patterns, s.Text())
@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)}, "mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
} }
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr return nil, true, rerr
}).Walk(f, "*") }).Walk(f, "*")
require.ErrorIs(t, err, rerr) require.ErrorIs(t, err, rerr)
@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
assert.False(t, ok) assert.False(t, ok)
}) })
} }
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View file

@ -1,20 +0,0 @@
package aghtest
import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Exchanger is a mock aghnet.Exchanger implementation for tests.
type Exchanger struct {
Ups upstream.Upstream
}
// Exchange implements aghnet.Exchanger interface for *Exchanger.
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if e.Ups == nil {
e.Ups = &TestErrUpstream{}
}
return e.Ups.Exchange(req)
}

View file

@ -1,23 +0,0 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View file

@ -0,0 +1,135 @@
package aghtest
import (
"io/fs"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Interface Mocks
//
// Keep entities in this file in alphabetic order.
// Standard Library
// type check
var _ fs.FS = &FS{}
// FS is a mock [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ net.Listener = (*Listener)(nil)
// Listener is a mock [net.Listener] implementation for tests.
type Listener struct {
OnAccept func() (conn net.Conn, err error)
OnAddr func() (addr net.Addr)
OnClose func() (err error)
}
// Accept implements the [net.Listener] interface for *Listener.
func (l *Listener) Accept() (conn net.Conn, err error) {
return l.OnAccept()
}
// Addr implements the [net.Listener] interface for *Listener.
func (l *Listener) Addr() (addr net.Addr) {
return l.OnAddr()
}
// Close implements the [net.Listener] interface for *Listener.
func (l *Listener) Close() (err error) {
return l.OnClose()
}
// Module dnsproxy
// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}
// Module AdGuardHome
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View file

@ -0,0 +1,9 @@
package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
)
// type check
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)

View file

@ -1,46 +0,0 @@
package aghtest
import "io/fs"
// type check
var _ fs.FS = &FS{}
// FS is a mock fs.FS implementation to use in tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the fs.FS interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock fs.StatFS implementation to use in tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the fs.StatFS interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock fs.GlobFS implementation to use in tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the fs.GlobFS interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

View file

@ -6,12 +6,18 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/require"
) )
// Additional Upstream Testing Utilities
// Upstream is a mock implementation of upstream.Upstream. // Upstream is a mock implementation of upstream.Upstream.
//
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
type Upstream struct { type Upstream struct {
// CName is a map of hostname to canonical name. // CName is a map of hostname to canonical name.
CName map[string][]string CName map[string][]string
@ -25,6 +31,43 @@ type Upstream struct {
Addr string Addr string
} }
// RespondTo returns a response with answer if req has class cl, question type
// qt, and target targ.
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
t.Helper()
require.NotNil(t, req)
require.Len(t, req.Question, 1)
q := req.Question[0]
targ = dns.Fqdn(targ)
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
return nil
}
respHdr := dns.RR_Header{
Name: targ,
Rrtype: qt,
Class: cl,
Ttl: 60,
}
resp = new(dns.Msg).SetReply(req)
switch qt {
case dns.TypePTR:
resp.Answer = []dns.RR{
&dns.PTR{
Hdr: respHdr,
Ptr: answer,
},
}
default:
t.Fatalf("unsupported question type: %s", dns.Type(qt))
}
return resp
}
// Exchange implements the upstream.Upstream interface for *Upstream. // Exchange implements the upstream.Upstream interface for *Upstream.
// //
// TODO(a.garipov): Split further into handlers. // TODO(a.garipov): Split further into handlers.
@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
return u.Addr return u.Addr
} }
// TestBlockUpstream implements upstream.Upstream interface for replacing real // NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// upstream in tests. // supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
type TestBlockUpstream struct { // true, hostname's actual hash is returned, blocking it. Otherwise, it returns
Hostname string // a different hash.
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
// lock protects reqNum. hash := sha256.Sum256([]byte(hostname))
lock sync.RWMutex hashStr := hex.EncodeToString(hash[:])
reqNum int if !shouldBlock {
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
Block bool
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.reqNum++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
} }
m := &dns.Msg{} ans := &dns.TXT{
m.SetReply(r) Hdr: dns.RR_Header{
m.Answer = []dns.RR{ Name: "",
&dns.TXT{ Rrtype: dns.TypeTXT,
Hdr: dns.RR_Header{ Class: dns.ClassINET,
Name: r.Question[0].Name, Ttl: 60,
}, },
Txt: []string{ Txt: []string{hashStr},
hashToReturn, }
}, respTmpl := &dns.Msg{
Answer: []dns.RR{ans},
}
return &UpstreamMock{
OnAddress: func() (addr string) {
return "sbpc.upstream.example"
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = respTmpl.Copy()
resp.SetReply(req)
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
return resp, nil
}, },
} }
return m, nil
} }
// Address always returns an empty string. // ErrUpstream is the error returned from the [*UpstreamMock] created by
func (u *TestBlockUpstream) Address() string { // [NewErrorUpstream].
return "" const ErrUpstream errors.Error = "test upstream error"
}
// RequestsCount returns the number of handled requests. It's safe for // NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
// concurrent use. // its Exchange method.
func (u *TestBlockUpstream) RequestsCount() int { func NewErrorUpstream() (u *UpstreamMock) {
u.lock.Lock() return &UpstreamMock{
defer u.lock.Unlock() OnAddress: func() (addr string) {
return "error.upstream.example"
return u.reqNum },
} OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("test upstream error")
// TestErrUpstream implements upstream.Upstream interface for replacing real },
// upstream in tests. }
type TestErrUpstream struct {
// The error returned by Exchange may be unwrapped to the Err.
Err error
}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("errupstream: %w", u.Err)
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
} }

View file

@ -17,13 +17,13 @@ import (
"testing/fstest" "testing/fstest"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
func TestBlockedBySafeBrowsing(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru" const hostname = "wmconvirus.narod.ru"
sbUps := &aghtest.TestBlockUpstream{ sbUps := aghtest.NewBlockUpstream(hostname, true)
Hostname: hostname,
Block: true,
}
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &filtering.Config{ filterConf := &filtering.Config{
@ -1029,7 +1026,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.ProtectionEnabled = true
err = s.Prepare(nil) err = s.Prepare(nil)
require.NoError(t, err) require.NoError(t, err)
@ -1177,25 +1174,48 @@ func TestNewServer(t *testing.T) {
} }
func TestServer_Exchange(t *testing.T) { func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.Upstream{ const (
Reverse: map[string][]string{ onesHost = "one.one.one.one"
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, localDomainHost = "local.domain"
)
var (
onesIP = net.IP{1, 1, 1, 1}
localIP = net.IP{192, 168, 1, 1}
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{ revLocIPv4, err := netutil.IPToReversedAddr(localIP)
"1.1.168.192.in-addr.arpa.": {"local.domain"}, require.NoError(t, err)
"2.1.168.192.in-addr.arpa.": {},
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{ errUpstream := aghtest.NewErrorUpstream()
Err: upstreamErr, nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
}
nonPtrUpstream := &aghtest.TestBlockUpstream{
Hostname: "some-host",
Block: true,
}
srv := NewCustomServer(&proxy.Proxy{ srv := NewCustomServer(&proxy.Proxy{
Config: proxy.Config{ Config: proxy.Config{
@ -1209,7 +1229,6 @@ func TestServer_Exchange(t *testing.T) {
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
localIP := net.IP{192, 168, 1, 1}
testCases := []struct { testCases := []struct {
name string name string
want string want string
@ -1218,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
req net.IP req net.IP
}{{ }{{
name: "external_good", name: "external_good",
want: "one.one.one.one", want: onesHost,
wantErr: nil, wantErr: nil,
locUpstream: nil, locUpstream: nil,
req: net.IP{1, 1, 1, 1}, req: onesIP,
}, { }, {
name: "local_good", name: "local_good",
want: "local.domain", want: localDomainHost,
wantErr: nil, wantErr: nil,
locUpstream: locUpstream, locUpstream: locUpstream,
req: localIP, req: localIP,
}, { }, {
name: "upstream_error", name: "upstream_error",
want: "", want: "",
wantErr: upstreamErr, wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream, locUpstream: errUpstream,
req: localIP, req: localIP,
}, { }, {

View file

@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"
)
var setts = Settings{ var setts = Settings{
ProtectionEnabled: true, ProtectionEnabled: true,
} }
@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+matching) require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
// Cached result. // Cached result.
d.safeBrowsingServer = "127.0.0.1" d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, matching) d.checkMatch(t, sbBlocked)
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
} }
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: matching,
Block: true,
})
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel() t.Parallel()
d.checkMatch(t, matching) d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+matching) d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
}) })
} }
}) })
@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
d := newForTest(t, &Config{ParentalEnabled: true}, nil) d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching) d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
require.Contains(t, logOutput.String(), "Parental lookup for "+matching) d.checkMatch(t, pcBlocked)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
d.checkMatch(t, "www."+matching) d.checkMatch(t, "www."+pcBlocked)
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com") d.checkMatchEmpty(t, "api.jquery.com")
// Test cached result. // Test cached result.
d.parentalServer = "127.0.0.1" d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching) d.checkMatch(t, pcBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
} }
@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
}, { }, {
name: "sanity", name: "sanity",
rules: "||doubleclick.net^", rules: "||doubleclick.net^",
host: "wmconvirus.narod.ru", host: sbBlocked,
wantIsFiltered: false, wantIsFiltered: false,
wantReason: NotFilteredNotFound, wantReason: NotFilteredNotFound,
wantDNSType: dns.TypeA, wantDNSType: dns.TypeA,
@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
}}, }},
) )
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: "pornhub.com", d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
Block: true, d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
})
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: "wmconvirus.narod.ru",
Block: true,
})
type testCase struct { type testCase struct {
name string name string
@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
wantReason: FilteredBlockList, wantReason: FilteredBlockList,
}, { }, {
name: "parental", name: "parental",
host: "pornhub.com", host: pcBlocked,
before: true, before: true,
wantReason: FilteredParental, wantReason: FilteredParental,
}, { }, {
name: "safebrowsing", name: "safebrowsing",
host: "wmconvirus.narod.ru", host: sbBlocked,
before: false, before: false,
wantReason: FilteredSafeBrowsing, wantReason: FilteredSafeBrowsing,
}, { }, {
@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: blocked,
Block: true,
})
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err) require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
} }
} }
func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: blocked,
Block: true,
})
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err) require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
} }
}) })
} }

View file

@ -314,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
if log.GetLevel() >= log.DEBUG { if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer() timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host) defer timer.LogElapsed("safebrowsing lookup for %q", host)
} }
sctx := &sbCtx{ sctx := &sbCtx{
@ -348,7 +348,7 @@ func (d *DNSFilter) checkParental(
if log.GetLevel() >= log.DEBUG { if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer() timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host) defer timer.LogElapsed("parental lookup for %q", host)
} }
sctx := &sbCtx{ sctx := &sbCtx{

View file

@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached()) assert.Equal(t, -1, c.getCached())
// match "sub.host.com" from cache, // Match "sub.host.com" from cache. Another hash for "host.example" is not
// but another hash for "nonexisting.com" is not in cache // in the cache, so get data for it from the server.
// which means that we must get data from server for it
c.hashToHost = make(map[[32]byte]string) c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com")) hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com")) hash = sha256.Sum256([]byte("host.example"))
c.hashToHost[hash] = "nonexisting.com" c.hashToHost[hash] = "host.example"
assert.Empty(t, c.getCached()) assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com")) hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash] _, ok := c.hashToHost[hash]
assert.False(t, ok) assert.False(t, ok)
hash = sha256.Sum256([]byte("nonexisting.com")) hash = sha256.Sum256([]byte("host.example"))
_, ok = c.hashToHost[hash] _, ok = c.hashToHost[hash]
assert.True(t, ok) assert.True(t, ok)
@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ups := &aghtest.TestErrUpstream{} ups := aghtest.NewErrorUpstream()
d.SetSafeBrowsingUpstream(ups) d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
// Prepare the upstream. // Prepare the upstream.
ups := &aghtest.TestBlockUpstream{ ups := aghtest.NewBlockUpstream(hostname, tc.block)
Hostname: hostname,
Block: tc.block, var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++
return onExchange(req)
} }
d.SetSafeBrowsingUpstream(ups) d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits, tc.testCache.Stats().Hit) assert.Equal(t, hits, tc.testCache.Stats().Hit)
// There was one request to an upstream. // There was one request to an upstream.
assert.Equal(t, 1, ups.RequestsCount()) assert.Equal(t, 1, numReq)
// Now make the same request to check the cache was used. // Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname, dns.TypeA, setts) res, err = tc.testFunc(hostname, dns.TypeA, setts)
@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits+1, tc.testCache.Stats().Hit) assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
// Check that there were no additional requests. // Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount()) assert.Equal(t, 1, numReq)
}) })
purgeCaches(d) purgeCaches(d)

View file

@ -3,15 +3,16 @@ package home
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix())) binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
rdns := &RDNS{ rdns := &RDNS{
ipCache: ipCache, ipCache: ipCache,
exchanger: &rDNSExchanger{}, exchanger: &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
},
clients: &clientsContainer{ clients: &clientsContainer{
list: map[string]*Client{}, list: map[string]*Client{},
idIndex: tc.cliIDIndex, idIndex: tc.cliIDIndex,
@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests. // rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
type rDNSExchanger struct { type rDNSExchanger struct {
ex aghtest.Exchanger ex upstream.Upstream
usePrivate bool usePrivate bool
} }
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger. // Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) { func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
rev, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
}
req := &dns.Msg{ req := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{
Name: ip.String(), Name: dns.Fqdn(rev),
Qtype: dns.TypePTR, Qclass: dns.ClassINET,
Qtype: dns.TypePTR,
}}, }},
} }
@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
MaxCount: defaultRDNSCacheSize, MaxCount: defaultRDNSCacheSize,
}) })
ex := &rDNSExchanger{} ex := &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
}
rdns := &RDNS{ rdns := &RDNS{
ipCache: ipCache, ipCache: ipCache,
@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w) aghtest.ReplaceLogWriter(t, w)
locUpstream := &aghtest.Upstream{ localIP := net.IP{192, 168, 1, 1}
Reverse: map[string][]string{ revIPv4, err := netutil.IPToReversedAddr(localIP)
"192.168.1.1": {"local.domain"}, require.NoError(t, err)
"2a00:1450:400c:c06::93": {"ipv6.domain"},
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
errUpstream := &aghtest.TestErrUpstream{
Err: errors.Error("1234"), errUpstream := aghtest.NewErrorUpstream()
}
testCases := []struct { testCases := []struct {
ups upstream.Upstream ups upstream.Upstream
@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ups: locUpstream, ups: locUpstream,
wantLog: "", wantLog: "",
name: "all_good", name: "all_good",
cliIP: net.IP{192, 168, 1, 1}, cliIP: localIP,
}, { }, {
ups: errUpstream, ups: errUpstream,
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`, wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error", name: "resolve_error",
cliIP: net.IP{192, 168, 1, 2}, cliIP: net.IP{192, 168, 1, 2},
}, { }, {
@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ch := make(chan net.IP) ch := make(chan net.IP)
rdns := &RDNS{ rdns := &RDNS{
exchanger: &rDNSExchanger{ exchanger: &rDNSExchanger{
ex: aghtest.Exchanger{ ex: tc.ups,
Ups: tc.ups,
},
}, },
clients: cc, clients: cc,
ipCh: ch, ipCh: ch,

View file

@ -0,0 +1,193 @@
// Package dnssvc contains the AdGuard Home DNS service.
//
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
// receiver.
package dnssvc
import (
"context"
"fmt"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
// and replacement of module dnsproxy.
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
)
// Config is the AdGuard Home DNS service configuration structure.
//
// TODO(a.garipov): Add timeout for incoming requests.
type Config struct {
// Addresses are the addresses on which to serve plain DNS queries.
Addresses []netip.AddrPort
// Upstreams are the DNS upstreams to use. If not set, upstreams are
// created using data from BootstrapServers, UpstreamServers, and
// UpstreamTimeout.
//
// TODO(a.garipov): Think of a better scheme. Those other three parameters
// are here only to make Config work properly.
Upstreams []upstream.Upstream
// BootstrapServers are the addresses for bootstrapping the upstream DNS
// server addresses.
BootstrapServers []string
// UpstreamServers are the upstream DNS server addresses to use.
UpstreamServers []string
// UpstreamTimeout is the timeout for upstream requests.
UpstreamTimeout time.Duration
}
// Service is the AdGuard Home DNS service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
proxy *proxy.Proxy
bootstraps []string
upstreams []string
upsTimeout time.Duration
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
func New(c *Config) (svc *Service, err error) {
if c == nil {
return nil, nil
}
svc = &Service{
bootstraps: c.BootstrapServers,
upstreams: c.UpstreamServers,
upsTimeout: c.UpstreamTimeout,
}
var upstreams []upstream.Upstream
if len(c.Upstreams) > 0 {
upstreams = c.Upstreams
} else {
upstreams, err = addressesToUpstreams(
c.UpstreamServers,
c.BootstrapServers,
c.UpstreamTimeout,
)
if err != nil {
return nil, fmt.Errorf("converting upstreams: %w", err)
}
}
svc.proxy = &proxy.Proxy{
Config: proxy.Config{
UDPListenAddr: udpAddrs(c.Addresses),
TCPListenAddr: tcpAddrs(c.Addresses),
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: upstreams,
},
},
}
err = svc.proxy.Init()
if err != nil {
return nil, fmt.Errorf("proxy: %w", err)
}
return svc, nil
}
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// accepts a slice of addresses and other upstream parameters, and returns a
// slice of upstreams.
func addressesToUpstreams(
upsStrs []string,
bootstraps []string,
timeout time.Duration,
) (upstreams []upstream.Upstream, err error) {
upstreams = make([]upstream.Upstream, len(upsStrs))
for i, upsStr := range upsStrs {
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
Bootstrap: bootstraps,
Timeout: timeout,
})
if err != nil {
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
}
}
return upstreams, nil
}
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
if addrPorts == nil {
return nil
}
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
for i, a := range addrPorts {
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
}
return tcpAddrs
}
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
if addrPorts == nil {
return nil
}
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
for i, a := range addrPorts {
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
}
return udpAddrs
}
// type check
var _ agh.Service = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all DNS servers have tried to start, but there is no
// guarantee that they did. Errors from the servers are written to the log.
func (svc *Service) Start() (err error) {
if svc == nil {
return nil
}
return svc.proxy.Start()
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil {
return nil
}
return svc.proxy.Stop()
}
// Config returns the current configuration of the web service.
func (svc *Service) Config() (c *Config) {
// TODO(a.garipov): Do we need to get the TCP addresses separately?
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs := make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
}
c = &Config{
Addresses: addrs,
BootstrapServers: svc.bootstraps,
UpstreamServers: svc.upstreams,
UpstreamTimeout: svc.upsTimeout,
}
return c
}

View file

@ -0,0 +1,89 @@
package dnssvc_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 100 * time.Millisecond
func TestService(t *testing.T) {
const (
bootstrapAddr = "bootstrap.example"
upstreamAddr = "upstream.example"
)
ups := &aghtest.UpstreamMock{
OnAddress: func() (addr string) {
return upstreamAddr
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = (&dns.Msg{}).SetReply(req)
return resp, nil
},
}
c := &dnssvc.Config{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
Upstreams: []upstream.Upstream{ups},
BootstrapServers: []string{bootstrapAddr},
UpstreamServers: []string{upstreamAddr},
UpstreamTimeout: testTimeout,
}
svc, err := dnssvc.New(c)
require.NoError(t, err)
err = svc.Start()
require.NoError(t, err)
gotConf := svc.Config()
require.NotNil(t, gotConf)
require.Len(t, gotConf.Addresses, 1)
addr := gotConf.Addresses[0]
t.Run("dns", func(t *testing.T) {
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{}
resp, _, excErr := cli.ExchangeContext(ctx, req, addr.String())
require.NoError(t, excErr)
assert.NotNil(t, resp)
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
err = svc.Shutdown(ctx)
require.NoError(t, err)
}

View file

@ -40,8 +40,8 @@ type Config struct {
Timeout time.Duration Timeout time.Duration
} }
// Service is the AdGuard Home web service. A nil *Service is a valid service // Service is the AdGuard Home web service. A nil *Service is a valid
// that does nothing. // [agh.Service] that does nothing.
type Service struct { type Service struct {
tls *tls.Config tls *tls.Config
servers []*http.Server servers []*http.Server
@ -155,7 +155,7 @@ type unit = struct{}
// type check // type check
var _ agh.Service = (*Service)(nil) var _ agh.Service = (*Service)(nil)
// Start implements the agh.Service interface for *Service. svc may be nil. // Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all HTTP servers have tried to start, possibly failing and // After Start exits, all HTTP servers have tried to start, possibly failing and
// writing error messages to the log. // writing error messages to the log.
func (svc *Service) Start() (err error) { func (svc *Service) Start() (err error) {
@ -205,7 +205,8 @@ func serve(srv *http.Server, wg *sync.WaitGroup) {
} }
} }
// Shutdown implements the agh.Service interface for *Service. svc may be nil. // Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) { func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil { if svc == nil {
return nil return nil