mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-21 12:35:33 +03:00
Pull request #1558: add-dnssvc
Merge in DNS/adguard-home from add-dnssvc to master
Squashed commit of the following:
commit 55f4f114bab65a03c0d65383e89020a7356cff32
Merge: 95dc28d9 6e63757f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Mon Aug 15 20:53:07 2022 +0300
Merge branch 'master' into add-dnssvc
commit 95dc28d9d77d06e8ac98c1e6772557bffbf1705b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Mon Aug 15 20:52:50 2022 +0300
all: imp tests, docs
commit 0d9d02950d84afd160b4b1c118da856cee6f12e5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Thu Aug 11 19:27:59 2022 +0300
all: imp docs
commit 8990e038a81da4430468da12fcebedf79fe14df6
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Thu Aug 11 19:05:29 2022 +0300
all: imp tests more
commit 92730d93a2a1ac77888c2655508e43efaf0e9fde
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Thu Aug 11 18:37:48 2022 +0300
all: imp tests more
commit 8cd45ba30da7ac310e9dc666fb2af438e577b02d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Thu Aug 11 18:11:15 2022 +0300
all: add v1 dnssvc stub; refactor tests
This commit is contained in:
parent
6e63757fc7
commit
d4c3a43bcb
19 changed files with 742 additions and 319 deletions
|
@ -10,6 +10,20 @@ import (
|
|||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Coalesce returns the first non-zero value. It is named after the function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
|
||||
// value.
|
||||
func Coalesce[T comparable](values ...T) (res T) {
|
||||
var zero T
|
||||
for _, v := range values {
|
||||
if v != zero {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
// UniqChecker allows validating uniqueness of comparable items.
|
||||
//
|
||||
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
|
||||
|
|
|
@ -470,7 +470,7 @@ func TestHostsContainer(t *testing.T) {
|
|||
}},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "nonexisting",
|
||||
Hostname: "nonexistent.example",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "non-existing",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
|
57
internal/aghos/filewalker_internal_test.go
Normal file
57
internal/aghos/filewalker_internal_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// errFS is an fs.FS implementation, method Open of which always returns
|
||||
// errFSOpen.
|
||||
type errFS struct{}
|
||||
|
||||
// errFSOpen is returned from errGlobFS.Open.
|
||||
const errFSOpen errors.Error = "test open error"
|
||||
|
||||
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
|
||||
// err is always errFSOpen.
|
||||
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
|
||||
return nil, errFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
|
@ -1,13 +1,13 @@
|
|||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -16,7 +16,7 @@ import (
|
|||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
|
@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
|||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
|
@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
|||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
|
@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
|
|||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
package aghtest
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Exchanger is a mock aghnet.Exchanger implementation for tests.
|
||||
type Exchanger struct {
|
||||
Ups upstream.Upstream
|
||||
}
|
||||
|
||||
// Exchange implements aghnet.Exchanger interface for *Exchanger.
|
||||
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
if e.Ups == nil {
|
||||
e.Ups = &TestErrUpstream{}
|
||||
}
|
||||
|
||||
return e.Ups.Exchange(req)
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
135
internal/aghtest/interface.go
Normal file
135
internal/aghtest/interface.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package aghtest
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Interface Mocks
|
||||
//
|
||||
// Keep entities in this file in alphabetic order.
|
||||
|
||||
// Standard Library
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock [fs.FS] implementation for tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the [fs.FS] interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock [fs.GlobFS] implementation for tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock [fs.StatFS] implementation for tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ net.Listener = (*Listener)(nil)
|
||||
|
||||
// Listener is a mock [net.Listener] implementation for tests.
|
||||
type Listener struct {
|
||||
OnAccept func() (conn net.Conn, err error)
|
||||
OnAddr func() (addr net.Addr)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Accept implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
return l.OnAccept()
|
||||
}
|
||||
|
||||
// Addr implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Addr() (addr net.Addr) {
|
||||
return l.OnAddr()
|
||||
}
|
||||
|
||||
// Close implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Close() (err error) {
|
||||
return l.OnClose()
|
||||
}
|
||||
|
||||
// Module dnsproxy
|
||||
|
||||
// type check
|
||||
var _ upstream.Upstream = (*UpstreamMock)(nil)
|
||||
|
||||
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
|
||||
//
|
||||
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
|
||||
// rename it to just Upstream.
|
||||
type UpstreamMock struct {
|
||||
OnAddress func() (addr string)
|
||||
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
|
||||
}
|
||||
|
||||
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Address() (addr string) {
|
||||
return u.OnAddress()
|
||||
}
|
||||
|
||||
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return u.OnExchange(req)
|
||||
}
|
||||
|
||||
// Module AdGuardHome
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
||||
|
||||
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
9
internal/aghtest/interface_test.go
Normal file
9
internal/aghtest/interface_test.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package aghtest_test
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)
|
|
@ -1,46 +0,0 @@
|
|||
package aghtest
|
||||
|
||||
import "io/fs"
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock fs.FS implementation to use in tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the fs.FS interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock fs.StatFS implementation to use in tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the fs.StatFS interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock fs.GlobFS implementation to use in tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the fs.GlobFS interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
|
@ -6,12 +6,18 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Additional Upstream Testing Utilities
|
||||
|
||||
// Upstream is a mock implementation of upstream.Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
|
||||
type Upstream struct {
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string][]string
|
||||
|
@ -25,6 +31,43 @@ type Upstream struct {
|
|||
Addr string
|
||||
}
|
||||
|
||||
// RespondTo returns a response with answer if req has class cl, question type
|
||||
// qt, and target targ.
|
||||
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Question, 1)
|
||||
|
||||
q := req.Question[0]
|
||||
targ = dns.Fqdn(targ)
|
||||
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
|
||||
return nil
|
||||
}
|
||||
|
||||
respHdr := dns.RR_Header{
|
||||
Name: targ,
|
||||
Rrtype: qt,
|
||||
Class: cl,
|
||||
Ttl: 60,
|
||||
}
|
||||
|
||||
resp = new(dns.Msg).SetReply(req)
|
||||
switch qt {
|
||||
case dns.TypePTR:
|
||||
resp.Answer = []dns.RR{
|
||||
&dns.PTR{
|
||||
Hdr: respHdr,
|
||||
Ptr: answer,
|
||||
},
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unsupported question type: %s", dns.Type(qt))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Split further into handlers.
|
||||
|
@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
|
|||
return u.Addr
|
||||
}
|
||||
|
||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestBlockUpstream struct {
|
||||
Hostname string
|
||||
|
||||
// lock protects reqNum.
|
||||
lock sync.RWMutex
|
||||
reqNum int
|
||||
|
||||
Block bool
|
||||
}
|
||||
|
||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||
// pair.
|
||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.reqNum++
|
||||
|
||||
hash := sha256.Sum256([]byte(u.Hostname))
|
||||
hashToReturn := hex.EncodeToString(hash[:])
|
||||
if !u.Block {
|
||||
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
|
||||
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
|
||||
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
|
||||
// a different hash.
|
||||
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
|
||||
hash := sha256.Sum256([]byte(hostname))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
if !shouldBlock {
|
||||
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
m.Answer = []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
},
|
||||
Txt: []string{
|
||||
hashToReturn,
|
||||
},
|
||||
ans := &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "",
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Txt: []string{hashStr},
|
||||
}
|
||||
respTmpl := &dns.Msg{
|
||||
Answer: []dns.RR{ans},
|
||||
}
|
||||
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "sbpc.upstream.example"
|
||||
},
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = respTmpl.Copy()
|
||||
resp.SetReply(req)
|
||||
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestBlockUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
// ErrUpstream is the error returned from the [*UpstreamMock] created by
|
||||
// [NewErrorUpstream].
|
||||
const ErrUpstream errors.Error = "test upstream error"
|
||||
|
||||
// RequestsCount returns the number of handled requests. It's safe for
|
||||
// concurrent use.
|
||||
func (u *TestBlockUpstream) RequestsCount() int {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
return u.reqNum
|
||||
}
|
||||
|
||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestErrUpstream struct {
|
||||
// The error returned by Exchange may be unwrapped to the Err.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("errupstream: %w", u.Err)
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestErrUpstream) Address() string {
|
||||
return ""
|
||||
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
|
||||
// its Exchange method.
|
||||
func NewErrorUpstream() (u *UpstreamMock) {
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "error.upstream.example"
|
||||
},
|
||||
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return nil, errors.Error("test upstream error")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,13 +17,13 @@ import (
|
|||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
|
@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
|
|||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
const hostname = "wmconvirus.narod.ru"
|
||||
|
||||
sbUps := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: true,
|
||||
}
|
||||
sbUps := aghtest.NewBlockUpstream(hostname, true)
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
|
@ -1029,7 +1026,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
|||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
s.conf.ProtectionEnabled = true
|
||||
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
@ -1177,25 +1174,48 @@ func TestNewServer(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
localDomainHost = "local.domain"
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = net.IP{1, 1, 1, 1}
|
||||
localIP = net.IP{192, 168, 1, 1}
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
|
||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
upstreamErr := errors.Error("upstream error")
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: upstreamErr,
|
||||
}
|
||||
nonPtrUpstream := &aghtest.TestBlockUpstream{
|
||||
Hostname: "some-host",
|
||||
Block: true,
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
||||
|
||||
srv := NewCustomServer(&proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
|
@ -1209,7 +1229,6 @@ func TestServer_Exchange(t *testing.T) {
|
|||
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
|
@ -1218,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
|
|||
req net.IP
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: "one.one.one.one",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: net.IP{1, 1, 1, 1},
|
||||
req: onesIP,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: "local.domain",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: upstreamErr,
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
|
|
|
@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
|
|||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sbBlocked = "wmconvirus.narod.ru"
|
||||
pcBlocked = "pornhub.com"
|
||||
)
|
||||
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
|
@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
|
|||
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
d.checkMatch(t, matching)
|
||||
|
||||
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
d.checkMatch(t, sbBlocked)
|
||||
|
||||
d.checkMatch(t, "test."+matching)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
|
||||
// Cached result.
|
||||
d.safeBrowsingServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, "test."+matching)
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
|
|||
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.checkMatch(t, matching)
|
||||
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.checkMatch(t, pcBlocked)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||
|
||||
d.checkMatch(t, "www."+matching)
|
||||
d.checkMatch(t, "www."+pcBlocked)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
}
|
||||
|
||||
|
@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
|
|||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
|
@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
|
|||
}},
|
||||
)
|
||||
t.Cleanup(d.Close)
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "pornhub.com",
|
||||
Block: true,
|
||||
})
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "wmconvirus.narod.ru",
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
|
@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
|
|||
wantReason: FilteredBlockList,
|
||||
}, {
|
||||
name: "parental",
|
||||
host: "pornhub.com",
|
||||
host: pcBlocked,
|
||||
before: true,
|
||||
wantReason: FilteredParental,
|
||||
}, {
|
||||
name: "safebrowsing",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
before: false,
|
||||
wantReason: FilteredSafeBrowsing,
|
||||
}, {
|
||||
|
@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
|
|||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -314,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
|||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
|
||||
defer timer.LogElapsed("safebrowsing lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
|
@ -348,7 +348,7 @@ func (d *DNSFilter) checkParental(
|
|||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("Parental lookup for %s", host)
|
||||
defer timer.LogElapsed("parental lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
|
|
|
@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
|
|||
c.hashToHost[hash] = "sub.host.com"
|
||||
assert.Equal(t, -1, c.getCached())
|
||||
|
||||
// match "sub.host.com" from cache,
|
||||
// but another hash for "nonexisting.com" is not in cache
|
||||
// which means that we must get data from server for it
|
||||
// Match "sub.host.com" from cache. Another hash for "host.example" is not
|
||||
// in the cache, so get data for it from the server.
|
||||
c.hashToHost = make(map[[32]byte]string)
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
c.hashToHost[hash] = "nonexisting.com"
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
c.hashToHost[hash] = "host.example"
|
||||
assert.Empty(t, c.getCached())
|
||||
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
_, ok := c.hashToHost[hash]
|
||||
assert.False(t, ok)
|
||||
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
_, ok = c.hashToHost[hash]
|
||||
assert.True(t, ok)
|
||||
|
||||
|
@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
|||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
|
||||
ups := aghtest.NewErrorUpstream()
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
|
@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
ups := aghtest.NewBlockUpstream(hostname, tc.block)
|
||||
|
||||
var numReq int
|
||||
onExchange := ups.OnExchange
|
||||
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
numReq++
|
||||
|
||||
return onExchange(req)
|
||||
}
|
||||
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
|
@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
|
|||
assert.Equal(t, hits, tc.testCache.Stats().Hit)
|
||||
|
||||
// There was one request to an upstream.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
|
||||
// Now make the same request to check the cache was used.
|
||||
res, err = tc.testFunc(hostname, dns.TypeA, setts)
|
||||
|
@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
|
|||
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
|
||||
|
||||
// Check that there were no additional requests.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
})
|
||||
|
||||
purgeCaches(d)
|
||||
|
|
|
@ -3,15 +3,16 @@ package home
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
|
@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
|
|||
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{},
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
},
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
|
@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
|
|||
|
||||
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
|
||||
type rDNSExchanger struct {
|
||||
ex aghtest.Exchanger
|
||||
ex upstream.Upstream
|
||||
usePrivate bool
|
||||
}
|
||||
|
||||
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
|
||||
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
|
||||
rev, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: ip.String(),
|
||||
Qtype: dns.TypePTR,
|
||||
Name: dns.Fqdn(rev),
|
||||
Qclass: dns.ClassINET,
|
||||
Qtype: dns.TypePTR,
|
||||
}},
|
||||
}
|
||||
|
||||
|
@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
|
|||
MaxCount: defaultRDNSCacheSize,
|
||||
})
|
||||
|
||||
ex := &rDNSExchanger{}
|
||||
ex := &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
}
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
|
@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
|||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"192.168.1.1": {"local.domain"},
|
||||
"2a00:1450:400c:c06::93": {"ipv6.domain"},
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: errors.Error("1234"),
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
|
||||
testCases := []struct {
|
||||
ups upstream.Upstream
|
||||
|
@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
|||
ups: locUpstream,
|
||||
wantLog: "",
|
||||
name: "all_good",
|
||||
cliIP: net.IP{192, 168, 1, 1},
|
||||
cliIP: localIP,
|
||||
}, {
|
||||
ups: errUpstream,
|
||||
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`,
|
||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||
name: "resolve_error",
|
||||
cliIP: net.IP{192, 168, 1, 2},
|
||||
}, {
|
||||
|
@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
|||
ch := make(chan net.IP)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.Exchanger{
|
||||
Ups: tc.ups,
|
||||
},
|
||||
ex: tc.ups,
|
||||
},
|
||||
clients: cc,
|
||||
ipCh: ch,
|
||||
|
|
193
internal/v1/dnssvc/dnssvc.go
Normal file
193
internal/v1/dnssvc/dnssvc.go
Normal file
|
@ -0,0 +1,193 @@
|
|||
// Package dnssvc contains the AdGuard Home DNS service.
|
||||
//
|
||||
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
|
||||
// receiver.
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
|
||||
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
|
||||
// and replacement of module dnsproxy.
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home DNS service configuration structure.
|
||||
//
|
||||
// TODO(a.garipov): Add timeout for incoming requests.
|
||||
type Config struct {
|
||||
// Addresses are the addresses on which to serve plain DNS queries.
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
// Upstreams are the DNS upstreams to use. If not set, upstreams are
|
||||
// created using data from BootstrapServers, UpstreamServers, and
|
||||
// UpstreamTimeout.
|
||||
//
|
||||
// TODO(a.garipov): Think of a better scheme. Those other three parameters
|
||||
// are here only to make Config work properly.
|
||||
Upstreams []upstream.Upstream
|
||||
|
||||
// BootstrapServers are the addresses for bootstrapping the upstream DNS
|
||||
// server addresses.
|
||||
BootstrapServers []string
|
||||
|
||||
// UpstreamServers are the upstream DNS server addresses to use.
|
||||
UpstreamServers []string
|
||||
|
||||
// UpstreamTimeout is the timeout for upstream requests.
|
||||
UpstreamTimeout time.Duration
|
||||
}
|
||||
|
||||
// Service is the AdGuard Home DNS service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
proxy *proxy.Proxy
|
||||
bootstraps []string
|
||||
upstreams []string
|
||||
upsTimeout time.Duration
|
||||
}
|
||||
|
||||
// New returns a new properly initialized *Service. If c is nil, svc is a nil
|
||||
// *Service that does nothing. The fields of c must not be modified after
|
||||
// calling New.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
bootstraps: c.BootstrapServers,
|
||||
upstreams: c.UpstreamServers,
|
||||
upsTimeout: c.UpstreamTimeout,
|
||||
}
|
||||
|
||||
var upstreams []upstream.Upstream
|
||||
if len(c.Upstreams) > 0 {
|
||||
upstreams = c.Upstreams
|
||||
} else {
|
||||
upstreams, err = addressesToUpstreams(
|
||||
c.UpstreamServers,
|
||||
c.BootstrapServers,
|
||||
c.UpstreamTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
svc.proxy = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = svc.proxy.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: %w", err)
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||
// slice of upstreams.
|
||||
func addressesToUpstreams(
|
||||
upsStrs []string,
|
||||
bootstraps []string,
|
||||
timeout time.Duration,
|
||||
) (upstreams []upstream.Upstream, err error) {
|
||||
upstreams = make([]upstream.Upstream, len(upsStrs))
|
||||
for i, upsStr := range upsStrs {
|
||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
|
||||
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
|
||||
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all DNS servers have tried to start, but there is no
|
||||
// guarantee that they did. Errors from the servers are written to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return svc.proxy.Start()
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return svc.proxy.Stop()
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
// TODO(a.garipov): Do we need to get the TCP addresses separately?
|
||||
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
|
||||
addrs := make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.(*net.UDPAddr).AddrPort()
|
||||
}
|
||||
|
||||
c = &Config{
|
||||
Addresses: addrs,
|
||||
BootstrapServers: svc.bootstraps,
|
||||
UpstreamServers: svc.upstreams,
|
||||
UpstreamTimeout: svc.upsTimeout,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
89
internal/v1/dnssvc/dnssvc_test.go
Normal file
89
internal/v1/dnssvc/dnssvc_test.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package dnssvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 100 * time.Millisecond
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
const (
|
||||
bootstrapAddr = "bootstrap.example"
|
||||
upstreamAddr = "upstream.example"
|
||||
)
|
||||
|
||||
ups := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return upstreamAddr
|
||||
},
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = (&dns.Msg{}).SetReply(req)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
|
||||
c := &dnssvc.Config{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
BootstrapServers: []string{bootstrapAddr},
|
||||
UpstreamServers: []string{upstreamAddr},
|
||||
UpstreamTimeout: testTimeout,
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
gotConf := svc.Config()
|
||||
require.NotNil(t, gotConf)
|
||||
require.Len(t, gotConf.Addresses, 1)
|
||||
|
||||
addr := gotConf.Addresses[0]
|
||||
|
||||
t.Run("dns", func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
cli := &dns.Client{}
|
||||
resp, _, excErr := cli.ExchangeContext(ctx, req, addr.String())
|
||||
require.NoError(t, excErr)
|
||||
|
||||
assert.NotNil(t, resp)
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
|
@ -40,8 +40,8 @@ type Config struct {
|
|||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid service
|
||||
// that does nothing.
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
tls *tls.Config
|
||||
servers []*http.Server
|
||||
|
@ -155,7 +155,7 @@ type unit = struct{}
|
|||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the agh.Service interface for *Service. svc may be nil.
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
||||
// writing error messages to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
|
@ -205,7 +205,8 @@ func serve(srv *http.Server, wg *sync.WaitGroup) {
|
|||
}
|
||||
}
|
||||
|
||||
// Shutdown implements the agh.Service interface for *Service. svc may be nil.
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
|
|
Loading…
Reference in a new issue