Pull request 1803: 5685-fix-safe-search

Updates #5685.

Squashed commit of the following:

commit 5312147abfa0914c896acbf1e88f8c8f1af90f2b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Apr 6 14:09:44 2023 +0300

    safesearch: imp tests, logs

commit 298b5d24ce292c5f83ebe33d1e92329e4b3c1acc
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:36:16 2023 +0300

    safesearch: fix filters, logging

commit 63d6ca5d694d45705473f2f0410e9e0b49cf7346
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:24:47 2023 +0300

    all: dry; fix logs

commit fdbf2f364fd0484b47b3161bf6f4581856fdf47b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:01:08 2023 +0300

    all: fix safe search update
This commit is contained in:
Ainar Garipov 2023-04-06 14:12:50 +03:00
parent a55cbbe79c
commit 61b4043775
19 changed files with 527 additions and 265 deletions

View file

@ -25,6 +25,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Added
- IPv6 support in Safe Search for some services.
- The ability to make bootstrap DNS lookups prefer IPv6 addresses to IPv4 ones
using the new `dns.bootstrap_prefer_ipv6` configuration file property
([#4262]).

View file

@ -453,8 +453,9 @@ func TestSafeSearch(t *testing.T) {
SafeSearchCacheSize: 1000,
CacheTime: 30,
}
safeSearch, err := safesearch.NewDefaultSafeSearch(
safeSearch, err := safesearch.NewDefault(
safeSearchConf,
"",
filterConf.SafeSearchCacheSize,
time.Minute*time.Duration(filterConf.CacheTime),
)

View file

@ -1,17 +1,17 @@
package filtering
import (
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
import "github.com/miekg/dns"
// SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface {
// SearchHost returns a replacement address for the search engine host.
SearchHost(host string, qtype uint16) (res *rules.DNSRewrite)
// CheckHost checks host with safe search engine.
// CheckHost checks host with safe search filter. CheckHost must be safe
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
CheckHost(host string, qtype uint16) (res Result, err error)
// Update updates the configuration of the safe search filter. Update must
// be safe for concurrent use. An implementation of Update may ignore some
// fields, but it must document which.
Update(conf SafeSearchConfig) (err error)
}
// SafeSearchConfig is a struct with safe search related settings.
@ -37,10 +37,12 @@ type SafeSearchConfig struct {
// [hostChecker.check].
func (d *DNSFilter) checkSafeSearch(
host string,
_ uint16,
qtype uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
if !setts.ProtectionEnabled ||
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil
}
@ -50,8 +52,8 @@ func (d *DNSFilter) checkSafeSearch(
clientSafeSearch := setts.ClientSafeSearch
if clientSafeSearch != nil {
return clientSafeSearch.CheckHost(host, dns.TypeA)
return clientSafeSearch.CheckHost(host, qtype)
}
return d.safeSearch.CheckHost(host, dns.TypeA)
return d.safeSearch.CheckHost(host, qtype)
}

View file

@ -9,6 +9,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -53,44 +54,85 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
}
}
// DefaultSafeSearch is the default safesearch struct.
type DefaultSafeSearch struct {
engine *urlfilter.DNSEngine
safeSearchCache cache.Cache
resolver filtering.Resolver
cacheTime time.Duration
// Default is the default safe search filter that uses filtering rules with the
// dnsrewrite modifier.
type Default struct {
// mu protects engine.
mu *sync.RWMutex
// engine is the filtering engine that contains the DNS rewrite rules.
// engine may be nil, which means that this safe search filter is disabled.
engine *urlfilter.DNSEngine
cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
}
// NewDefaultSafeSearch returns new safesearch struct. CacheTime is an element
// TTL (in minutes).
func NewDefaultSafeSearch(
// NewDefault returns an initialized default safe search filter. name is used
// for logging.
func NewDefault(
conf filtering.SafeSearchConfig,
name string,
cacheSize uint,
cacheTime time.Duration,
) (ss *DefaultSafeSearch, err error) {
engine, err := newEngine(filtering.SafeSearchListID, conf)
if err != nil {
return nil, err
}
cacheTTL time.Duration,
) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
return &DefaultSafeSearch{
engine: engine,
safeSearchCache: cache.New(cache.Config{
ss = &Default{
mu: &sync.RWMutex{},
cache: cache.New(cache.Config{
EnableLRU: true,
MaxSize: cacheSize,
}),
cacheTime: cacheTime,
resolver: resolver,
}, nil
resolver: resolver,
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
cacheTTL: cacheTTL,
}
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return ss, nil
}
// newEngine creates new engine for provided safe search configuration.
func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.DNSEngine, err error) {
// log is a helper for logging that includes the name of the safe search
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
func (ss *Default) log(level log.Level, msg string, args ...any) {
switch level {
case log.DEBUG:
log.Debug(ss.logPrefix+msg, args...)
case log.INFO:
log.Info(ss.logPrefix+msg, args...)
case log.ERROR:
log.Error(ss.logPrefix+msg, args...)
default:
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
}
}
// resetEngine creates new engine for provided safe search configuration and
// sets it in ss.
func (ss *Default) resetEngine(
listID int,
conf filtering.SafeSearchConfig,
) (err error) {
if !conf.Enabled {
ss.log(log.INFO, "disabled")
return nil
}
var sb strings.Builder
for service, serviceRules := range safeSearchRules {
if isServiceProtected(conf, service) {
@ -106,20 +148,73 @@ func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.D
rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
if err != nil {
return nil, fmt.Errorf("creating rule storage: %w", err)
return fmt.Errorf("creating rule storage: %w", err)
}
engine = urlfilter.NewDNSEngine(rs)
log.Info("safesearch: filter %d: reset %d rules", listID, engine.RulesCount)
ss.engine = urlfilter.NewDNSEngine(rs)
return engine, nil
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
return nil
}
// type check
var _ filtering.SafeSearch = (*DefaultSafeSearch)(nil)
var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *Default) CheckHost(
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
start := time.Now()
defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}()
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host, qtype)
if isFound {
ss.log(log.DEBUG, "found in cache: %q", host)
return cachedValue, nil
}
rewrite := ss.searchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
fltRes, err := ss.newResult(rewrite, qtype)
if err != nil {
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
return filtering.Result{}, err
}
if fltRes != nil {
res = *fltRes
ss.setCacheResult(host, qtype, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
}
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRewrite) {
ss.mu.RLock()
defer ss.mu.RUnlock()
if ss.engine == nil {
return nil
}
// SearchHost implements the [filtering.SafeSearch] interface for *DefaultSafeSearch.
func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) {
r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{
Hostname: strings.ToLower(host),
DNSType: qtype,
@ -133,51 +228,11 @@ func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.D
return nil
}
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *DefaultSafeSearch) CheckHost(
host string,
qtype uint16,
) (res filtering.Result, err error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("safesearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host)
if isFound {
log.Debug("safesearch: found in cache: %s", host)
return cachedValue, nil
}
rewrite := ss.SearchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
dRes, err := ss.newResult(rewrite, qtype)
if err != nil {
log.Debug("safesearch: failed to lookup addresses for %s: %s", host, err)
return filtering.Result{}, err
}
if dRes != nil {
res = *dRes
ss.setCacheResult(host, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", host)
}
// newResult creates Result object from rewrite rule.
func (ss *DefaultSafeSearch) newResult(
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype uint16,
qtype rules.RRType,
) (res *filtering.Result, err error) {
res = &filtering.Result{
Rules: []*filtering.ResultRule{{
@ -187,7 +242,7 @@ func (ss *DefaultSafeSearch) newResult(
IsFiltered: true,
}
if rewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if rewrite.RRType == qtype {
ip, ok := rewrite.Value.(net.IP)
if !ok || ip == nil {
return nil, nil
@ -198,17 +253,25 @@ func (ss *DefaultSafeSearch) newResult(
return res, nil
}
if rewrite.NewCNAME == "" {
host := rewrite.NewCNAME
if host == "" {
return nil, nil
}
ips, err := ss.resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, err
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
if ip = ip.To4(); ip == nil {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
ip = fitToProto(ip, qtype)
if ip == nil {
continue
}
@ -220,38 +283,71 @@ func (ss *DefaultSafeSearch) newResult(
return nil, nil
}
// setCacheResult stores data in cache for host.
func (ss *DefaultSafeSearch) setCacheResult(host string, res filtering.Result) {
expire := uint32(time.Now().Add(ss.cacheTime).Unix())
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res net.IP) {
ip4 := ip.To4()
if qtype == dns.TypeA {
return ip4
}
if ip4 == nil {
return ip
}
return nil
}
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, expire)
buf := bytes.NewBuffer(exp)
err := gob.NewEncoder(buf).Encode(res)
if err != nil {
log.Error("safesearch: cache encoding: %s", err)
ss.log(log.ERROR, "cache encoding: %s", err)
return
}
val := buf.Bytes()
_ = ss.safeSearchCache.Set([]byte(host), val)
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
log.Debug("safesearch: stored in cache: %s (%d bytes)", host, len(val))
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
}
// getCachedResult returns stored data from cache for host.
func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, ok bool) {
// getCachedResult returns stored data from cache for host. qtype is expected
// to be either [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) getCachedResult(
host string,
qtype rules.RRType,
) (res filtering.Result, ok bool) {
res = filtering.Result{}
data := ss.safeSearchCache.Get([]byte(host))
data := ss.cache.Get([]byte(dns.Type(qtype).String() + " " + host))
if data == nil {
return res, false
}
exp := binary.BigEndian.Uint32(data[:4])
if exp <= uint32(time.Now().Unix()) {
ss.safeSearchCache.Del([]byte(host))
ss.cache.Del([]byte(host))
return res, false
}
@ -260,10 +356,27 @@ func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result,
err := gob.NewDecoder(buf).Decode(&res)
if err != nil {
log.Debug("safesearch: cache decoding: %s", err)
ss.log(log.ERROR, "cache decoding: %s", err)
return filtering.Result{}, false
}
return res, true
}
// Update implements the [filtering.SafeSearch] interface for *Default. Update
// ignores the CustomResolver and Enabled fields.
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
ss.mu.Lock()
defer ss.mu.Unlock()
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
ss.cache.Clear()
return nil
}

View file

@ -0,0 +1,137 @@
package safesearch
import (
"context"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(a.garipov): Move as much of this as possible into proper external tests.
const (
// TODO(a.garipov): Add IPv6 tests.
testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
)
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
Bing: true,
DuckDuckGo: true,
Google: true,
Pixabay: true,
Yandex: true,
YouTube: true,
}
var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.searchHost("www.google.com", testQType)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}

View file

@ -1,26 +1,37 @@
package safesearch
package safesearch_test
import (
"context"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// Common test constants.
const (
safeSearchCacheSize = 5000
cacheTime = 30 * time.Minute
// TODO(a.garipov): Add IPv6 tests.
testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
)
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
// testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true,
Bing: true,
DuckDuckGo: true,
Google: true,
@ -29,25 +40,15 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
YouTube: true,
}
// yandexIP is the expected IP address of Yandex safe search results. Keep in
// sync with the rules data.
var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *DefaultSafeSearch) {
ss, err := NewDefaultSafeSearch(ssConf, safeSearchCacheSize, cacheTime)
func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.SearchHost("www.google.com", dns.TypeA)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
// Check host for each domain.
for _, host := range []string{
"yandex.ru",
@ -57,7 +58,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"yandex.kz",
"www.yandex.com",
} {
res, err := ss.CheckHost(host, dns.TypeA)
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@ -69,12 +71,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{}
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
ss := newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
for _, host := range []string{
@ -87,7 +91,8 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
"www.google.je",
} {
t.Run(host, func(t *testing.T) {
res, err := ss.CheckHost(host, dns.TypeA)
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@ -100,103 +105,35 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
}
}
func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, dns.TypeA)
func TestDefault_Update(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, dns.TypeA)
res, err := ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.True(t, res.IsFiltered)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.SearchHost(domain, dns.TypeA)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.SearchHost(googleHost, dns.TypeA)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}
var dnsRewriteParallelSink *rules.DNSRewrite
func BenchmarkSafeSearch_parallel(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
dnsRewriteParallelSink = ss.SearchHost(googleHost, dns.TypeA)
}
err = ss.Update(filtering.SafeSearchConfig{
Enabled: true,
Google: false,
})
require.NoError(t, err)
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteParallelSink.NewCNAME)
res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{
Enabled: false,
Google: true,
})
require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
}

View file

@ -50,11 +50,19 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
return
}
conf := *req
err = d.safeSearch.Update(conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
return
}
func() {
d.confLock.Lock()
defer d.confLock.Unlock()
d.Config.SafeSearchConf = *req
d.Config.SafeSearchConf = conf
}()
d.Config.ConfigModified()

View file

@ -3,8 +3,10 @@ package home
import (
"encoding"
"fmt"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
)
@ -45,6 +47,23 @@ func (c *Client) closeUpstreams() (err error) {
return nil
}
// setSafeSearch initializes and sets the safe search filter for this client.
func (c *Client) setSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}
// clientSource represents the source from which the information about the
// client has been obtained.
type clientSource uint

View file

@ -13,7 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
@ -55,6 +54,14 @@ type clientsContainer struct {
// more detail.
lock sync.Mutex
// safeSearchCacheSize is the size of the safe search cache to use for
// persistent clients.
safeSearchCacheSize uint
// safeSearchCacheTTL is the TTL of the safe search cache to use for
// persistent clients.
safeSearchCacheTTL time.Duration
// testing is a flag that disables some features for internal tests.
//
// TODO(a.garipov): Awful. Remove.
@ -74,6 +81,7 @@ func (clients *clientsContainer) Init(
if clients.list != nil {
log.Fatal("clients.list != nil")
}
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
@ -85,6 +93,9 @@ func (clients *clientsContainer) Init(
clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf)
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing {
return
}
@ -171,18 +182,16 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
ss, err := safesearch.NewDefaultSafeSearch(
err := cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
)
if err != nil {
log.Error("clients: init client safesearch %s: %s", cli.Name, err)
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
continue
}
cli.SafeSearch = ss
}
for _, s := range o.BlockedServices {

View file

@ -9,17 +9,27 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClients(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer() (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
c.Init(nil, nil, nil, nil, &filtering.Config{})
return c
}
func TestClients(t *testing.T) {
clients := newClientsContainer()
t.Run("add_success", func(t *testing.T) {
var (
@ -198,10 +208,7 @@ func TestClients(t *testing.T) {
}
func TestClientsWHOIS(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
clients := newClientsContainer()
whois := &RuntimeClientWHOISInfo{
Country: "AU",
Orgname: "Example Org",
@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) {
}
func TestClientsAddExisting(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
clients := newClientsContainer()
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) {
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
clients := newClientsContainer()
// Add client with upstreams.
ok, err := clients.Add(&Client{

View file

@ -49,8 +49,8 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
}
@ -90,14 +90,16 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
}
// jsonToClient converts JSON object to Client object.
func jsonToClient(cj clientJSON) (c *Client) {
func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) {
var safeSearchConf filtering.SafeSearchConfig
if cj.SafeSearchConf != nil {
safeSearchConf = *cj.SafeSearchConf
} else {
// TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field.
safeSearchConf = filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled}
safeSearchConf = filtering.SafeSearchConfig{
Enabled: cj.SafeSearchEnabled,
}
// Set default service flags for enabled safesearch.
if safeSearchConf.Enabled {
@ -110,20 +112,35 @@ func jsonToClient(cj clientJSON) (c *Client) {
}
}
return &Client{
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
c = &Client{
safeSearchConf: safeSearchConf,
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
safeSearchConf: safeSearchConf,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
}
if safeSearchConf.Enabled {
err = c.setSafeSearch(
safeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
)
if err != nil {
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
}
}
return c, nil
}
// clientToJSON converts Client object to JSON.
@ -161,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
c := jsonToClient(cj)
c, err := clients.jsonToClient(cj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clients.Add(c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -224,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
c := jsonToClient(dj.Data)
c, err := clients.jsonToClient(dj.Data)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.Update(dj.Name, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View file

@ -545,6 +545,8 @@ var _ filtering.Resolver = safeSearchResolver{}
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.IP with IPv4 and IPv6 instances.
//
// TODO(a.garipov): Support network.
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(host)
if err != nil {

View file

@ -297,8 +297,9 @@ func setupConfig(opts options) (err error) {
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefaultSafeSearch(
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
config.DNS.DnsfilterConf.SafeSearchConf,
"default",
config.DNS.DnsfilterConf.SafeSearchCacheSize,
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
)
@ -869,8 +870,10 @@ func detectFirstRun() bool {
// Connect to a remote server resolving hostname using our own DNS server.
//
// TODO(e.burkov): This messy logic should be decomposed and clarified.
//
// TODO(a.garipov): Support network.
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Tracef("network:%v addr:%v", network, addr)
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
host, port, err := net.SplitHostPort(addr)
if err != nil {