From 792975e248bb04bce5a8ec767441fcf253c6d00f Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Thu, 3 Oct 2024 15:46:40 +0300 Subject: [PATCH] all: slog safesearch --- internal/client/persistent.go | 19 --- internal/dnsforward/dnsforward_test.go | 14 +- internal/filtering/safesearch.go | 13 +- internal/filtering/safesearch/safesearch.go | 125 +++++++++++------- .../safesearch/safesearch_internal_test.go | 19 ++- .../filtering/safesearch/safesearch_test.go | 54 +++++--- internal/filtering/safesearchhttp.go | 2 +- internal/home/clients.go | 27 +++- internal/home/clients_internal_test.go | 15 ++- internal/home/clientshttp.go | 22 ++- internal/home/clientshttp_internal_test.go | 11 +- internal/home/home.go | 29 ++-- 12 files changed, 224 insertions(+), 126 deletions(-) diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 4e09c5b8..7a0339b0 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -7,10 +7,8 @@ import ( "net/netip" "slices" "strings" - "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" @@ -323,20 +321,3 @@ func (c *Persistent) CloseUpstreams() (err error) { return nil } - -// SetSafeSearch initializes and sets the safe search filter for this client. -func (c *Persistent) SetSafeSearch( - conf filtering.SafeSearchConfig, - cacheSize uint, - cacheTTL time.Duration, -) (err error) { - ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - c.SafeSearch = ss - - return nil -} diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 25565503..94811545 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -513,12 +513,14 @@ func TestSafeSearch(t *testing.T) { SafeSearchCacheSize: 1000, CacheTime: 30, } - safeSearch, err := safesearch.NewDefault( - safeSearchConf, - "", - filterConf.SafeSearchCacheSize, - time.Minute*time.Duration(filterConf.CacheTime), - ) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + safeSearch, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: safeSearchConf, + CacheSize: filterConf.SafeSearchCacheSize, + CacheTTL: time.Minute * time.Duration(filterConf.CacheTime), + }) require.NoError(t, err) filterConf.SafeSearch = safeSearch diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index 50bba61d..b389573a 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -1,15 +1,17 @@ package filtering +import "context" + // SafeSearch interface describes a service for search engines hosts rewrites. type SafeSearch interface { // CheckHost checks host with safe search filter. CheckHost must be safe // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA]. - CheckHost(host string, qtype uint16) (res Result, err error) + CheckHost(ctx context.Context, host string, qtype uint16) (res Result, err error) // Update updates the configuration of the safe search filter. Update must // be safe for concurrent use. An implementation of Update may ignore some // fields, but it must document which. - Update(conf SafeSearchConfig) (err error) + Update(ctx context.Context, conf SafeSearchConfig) (err error) } // SafeSearchConfig is a struct with safe search related settings. @@ -40,10 +42,13 @@ func (d *DNSFilter) checkSafeSearch( return Result{}, nil } + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + clientSafeSearch := setts.ClientSafeSearch if clientSafeSearch != nil { - return clientSafeSearch.CheckHost(host, qtype) + return clientSafeSearch.CheckHost(ctx, host, qtype) } - return d.safeSearch.CheckHost(host, qtype) + return d.safeSearch.CheckHost(ctx, host, qtype) } diff --git a/internal/filtering/safesearch/safesearch.go b/internal/filtering/safesearch/safesearch.go index 9417102f..02227cf8 100644 --- a/internal/filtering/safesearch/safesearch.go +++ b/internal/filtering/safesearch/safesearch.go @@ -3,9 +3,11 @@ package safesearch import ( "bytes" + "context" "encoding/binary" "encoding/gob" "fmt" + "log/slog" "net/netip" "strings" "sync" @@ -14,10 +16,11 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/cache" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/rules" + "github.com/c2h5oh/datasize" "github.com/miekg/dns" ) @@ -57,9 +60,32 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool) } } +// DefaultConfig is the configuration structure for [Default]. +type DefaultConfig struct { + // BaseLogger is used to create logger for [Default]. + BaseLogger *slog.Logger + + // ClientName is the name of the persistent client associated with the safe + // search filter, if there is one. + ClientName string + + // CacheSize is the size of the filter results cache. + CacheSize uint + + // CacheTTL is the Time to Live duration for cached items. + CacheTTL time.Duration + + // ServicesConfig contains safe search settings for services. It must not + // be nil. + ServicesConfig filtering.SafeSearchConfig +} + // Default is the default safe search filter that uses filtering rules with the // dnsrewrite modifier. type Default struct { + // logger is used for logging the operation of the safe search filter. + logger *slog.Logger + // mu protects engine. mu *sync.RWMutex @@ -67,33 +93,29 @@ type Default struct { // engine may be nil, which means that this safe search filter is disabled. engine *urlfilter.DNSEngine - cache cache.Cache - logPrefix string - cacheTTL time.Duration + cache cache.Cache + cacheTTL time.Duration } // NewDefault returns an initialized default safe search filter. name is used // for logging. -func NewDefault( - conf filtering.SafeSearchConfig, - name string, - cacheSize uint, - cacheTTL time.Duration, -) (ss *Default, err error) { - ss = &Default{ - mu: &sync.RWMutex{}, - - cache: cache.New(cache.Config{ - EnableLRU: true, - MaxSize: cacheSize, - }), - // Use %s, because the client safe-search names already contain double - // quotes. - logPrefix: fmt.Sprintf("safesearch %s: ", name), - cacheTTL: cacheTTL, +func NewDefault(ctx context.Context, conf *DefaultConfig) (ss *Default, err error) { + logger := conf.BaseLogger.With(slogutil.KeyPrefix, "safesearch") + if conf.ClientName != "" { + logger = logger.With("client", conf.ClientName) } - err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) + ss = &Default{ + logger: logger, + mu: &sync.RWMutex{}, + cache: cache.New(cache.Config{ + EnableLRU: true, + MaxSize: conf.CacheSize, + }), + cacheTTL: conf.CacheTTL, + } + + err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf.ServicesConfig) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err @@ -102,29 +124,15 @@ func NewDefault( return ss, nil } -// log is a helper for logging that includes the name of the safe search -// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR]. -func (ss *Default) log(level log.Level, msg string, args ...any) { - switch level { - case log.DEBUG: - log.Debug(ss.logPrefix+msg, args...) - case log.INFO: - log.Info(ss.logPrefix+msg, args...) - case log.ERROR: - log.Error(ss.logPrefix+msg, args...) - default: - panic(fmt.Errorf("safesearch: unsupported logging level %d", level)) - } -} - // resetEngine creates new engine for provided safe search configuration and // sets it in ss. func (ss *Default) resetEngine( + ctx context.Context, listID int, conf filtering.SafeSearchConfig, ) (err error) { if !conf.Enabled { - ss.log(log.INFO, "disabled") + ss.logger.InfoContext(ctx, "disabled") return nil } @@ -149,7 +157,7 @@ func (ss *Default) resetEngine( ss.engine = urlfilter.NewDNSEngine(rs) - ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount) + ss.logger.InfoContext(ctx, "reset rules", "count", ss.engine.RulesCount) return nil } @@ -158,10 +166,14 @@ func (ss *Default) resetEngine( var _ filtering.SafeSearch = (*Default)(nil) // CheckHost implements the [filtering.SafeSearch] interface for *Default. -func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) { +func (ss *Default) CheckHost( + ctx context.Context, + host string, + qtype rules.RRType, +) (res filtering.Result, err error) { start := time.Now() defer func() { - ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) + ss.logger.DebugContext(ctx, "lookup finished", "host", host, "elapsed", time.Since(start)) }() switch qtype { @@ -172,9 +184,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res } // Check cache. Return cached result if it was found - cachedValue, isFound := ss.getCachedResult(host, qtype) + cachedValue, isFound := ss.getCachedResult(ctx, host, qtype) if isFound { - ss.log(log.DEBUG, "found in cache: %q", host) + ss.logger.DebugContext(ctx, "found in cache", "host", host) return cachedValue, nil } @@ -186,7 +198,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res fltRes, err := ss.newResult(rewrite, qtype) if err != nil { - ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err) + ss.logger.ErrorContext(ctx, "looking up addresses", "host", host, slogutil.KeyError, err) return filtering.Result{}, err } @@ -195,7 +207,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res // TODO(a.garipov): Consider switch back to resolving CNAME records IPs and // saving results to cache. - ss.setCacheResult(host, qtype, res) + ss.setCacheResult(ctx, host, qtype, res) return res, nil } @@ -255,7 +267,12 @@ func (ss *Default) newResult( // setCacheResult stores data in cache for host. qtype is expected to be either // [dns.TypeA] or [dns.TypeAAAA]. -func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { +func (ss *Default) setCacheResult( + ctx context.Context, + host string, + qtype rules.RRType, + res filtering.Result, +) { expire := uint32(time.Now().Add(ss.cacheTTL).Unix()) exp := make([]byte, 4) binary.BigEndian.PutUint32(exp, expire) @@ -263,7 +280,7 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering err := gob.NewEncoder(buf).Encode(res) if err != nil { - ss.log(log.ERROR, "cache encoding: %s", err) + ss.logger.ErrorContext(ctx, "cache encoding", slogutil.KeyError, err) return } @@ -271,12 +288,18 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering val := buf.Bytes() _ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val) - ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val)) + ss.logger.DebugContext( + ctx, + "stored in cache", + "host", host, + "entry_size", datasize.ByteSize(len(val)), + ) } // getCachedResult returns stored data from cache for host. qtype is expected // to be either [dns.TypeA] or [dns.TypeAAAA]. func (ss *Default) getCachedResult( + ctx context.Context, host string, qtype rules.RRType, ) (res filtering.Result, ok bool) { @@ -298,7 +321,7 @@ func (ss *Default) getCachedResult( err := gob.NewDecoder(buf).Decode(&res) if err != nil { - ss.log(log.ERROR, "cache decoding: %s", err) + ss.logger.ErrorContext(ctx, "cache decoding", slogutil.KeyError, err) return filtering.Result{}, false } @@ -308,11 +331,11 @@ func (ss *Default) getCachedResult( // Update implements the [filtering.SafeSearch] interface for *Default. Update // ignores the CustomResolver and Enabled fields. -func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) { +func (ss *Default) Update(ctx context.Context, conf filtering.SafeSearchConfig) (err error) { ss.mu.Lock() defer ss.mu.Unlock() - err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) + err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf) if err != nil { // Don't wrap the error, because it's informative enough as is. return err diff --git a/internal/filtering/safesearch/safesearch_internal_test.go b/internal/filtering/safesearch/safesearch_internal_test.go index 24282b75..4779cc53 100644 --- a/internal/filtering/safesearch/safesearch_internal_test.go +++ b/internal/filtering/safesearch/safesearch_internal_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -21,6 +23,9 @@ const ( testCacheTTL = 30 * time.Minute ) +// testTimeout is the common timeout for tests and contexts. +const testTimeout = 1 * time.Second + var defaultSafeSearchConf = filtering.SafeSearchConfig{ Enabled: true, Bing: true, @@ -35,7 +40,12 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{ var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56}) func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) { - ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL) + ss, err := NewDefault(testutil.ContextWithTimeout(t, testTimeout), &DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: ssConf, + CacheSize: testCacheSize, + CacheTTL: testCacheTTL, + }) require.NoError(t, err) return ss @@ -52,16 +62,17 @@ func TestSafeSearchCacheYandex(t *testing.T) { const domain = "yandex.ru" ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) + ctx := testutil.ContextWithTimeout(t, testTimeout) // Check host with disabled safesearch. - res, err := ss.CheckHost(domain, testQType) + res, err := ss.CheckHost(ctx, domain, testQType) require.NoError(t, err) assert.False(t, res.IsFiltered) assert.Empty(t, res.Rules) ss = newForTest(t, defaultSafeSearchConf) - res, err = ss.CheckHost(domain, testQType) + res, err = ss.CheckHost(ctx, domain, testQType) require.NoError(t, err) // For yandex we already know valid IP. @@ -70,7 +81,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { assert.Equal(t, res.Rules[0].IP, yandexIP) // Check cache. - cachedValue, isFound := ss.getCachedResult(domain, testQType) + cachedValue, isFound := ss.getCachedResult(ctx, domain, testQType) require.True(t, isFound) require.Len(t, cachedValue.Rules, 1) diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index bcd3534d..688d54b5 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -10,15 +10,15 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} +// testTimeout is the common timeout for tests and contexts. +const testTimeout = 1 * time.Second // Common test constants. const ( @@ -47,7 +47,13 @@ var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56}) func TestDefault_CheckHost_yandex(t *testing.T) { conf := testConf - ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: conf, + CacheSize: testCacheSize, + CacheTTL: testCacheTTL, + }) require.NoError(t, err) hosts := []string{ @@ -82,7 +88,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) { for _, host := range hosts { // Check host for each domain. var res filtering.Result - res, err = ss.CheckHost(host, tc.qt) + res, err = ss.CheckHost(ctx, host, tc.qt) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -103,7 +109,13 @@ func TestDefault_CheckHost_yandex(t *testing.T) { } func TestDefault_CheckHost_google(t *testing.T) { - ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: testConf, + CacheSize: testCacheSize, + CacheTTL: testCacheTTL, + }) require.NoError(t, err) // Check host for each domain. @@ -118,7 +130,7 @@ func TestDefault_CheckHost_google(t *testing.T) { } { t.Run(host, func(t *testing.T) { var res filtering.Result - res, err = ss.CheckHost(host, testQType) + res, err = ss.CheckHost(ctx, host, testQType) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -149,13 +161,19 @@ func (r *testResolver) LookupIP( } func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { - ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: testConf, + CacheSize: testCacheSize, + CacheTTL: testCacheTTL, + }) require.NoError(t, err) // The DuckDuckGo safe-search addresses are resolved through CNAMEs, but // DuckDuckGo doesn't have a safe-search IPv6 address. The result should be // the same as the one for Yandex IPv6. That is, a NODATA response. - res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA) + res, err := ss.CheckHost(ctx, "www.duckduckgo.com", dns.TypeAAAA) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -166,32 +184,38 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { func TestDefault_Update(t *testing.T) { conf := testConf - ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + ServicesConfig: conf, + CacheSize: testCacheSize, + CacheTTL: testCacheTTL, + }) require.NoError(t, err) - res, err := ss.CheckHost("www.yandex.com", testQType) + res, err := ss.CheckHost(ctx, "www.yandex.com", testQType) require.NoError(t, err) assert.True(t, res.IsFiltered) - err = ss.Update(filtering.SafeSearchConfig{ + err = ss.Update(ctx, filtering.SafeSearchConfig{ Enabled: true, Google: false, }) require.NoError(t, err) - res, err = ss.CheckHost("www.yandex.com", testQType) + res, err = ss.CheckHost(ctx, "www.yandex.com", testQType) require.NoError(t, err) assert.False(t, res.IsFiltered) - err = ss.Update(filtering.SafeSearchConfig{ + err = ss.Update(ctx, filtering.SafeSearchConfig{ Enabled: false, Google: true, }) require.NoError(t, err) - res, err = ss.CheckHost("www.yandex.com", testQType) + res, err = ss.CheckHost(ctx, "www.yandex.com", testQType) require.NoError(t, err) assert.False(t, res.IsFiltered) diff --git a/internal/filtering/safesearchhttp.go b/internal/filtering/safesearchhttp.go index eb6fa401..8790b297 100644 --- a/internal/filtering/safesearchhttp.go +++ b/internal/filtering/safesearchhttp.go @@ -51,7 +51,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ } conf := *req - err = d.safeSearch.Update(conf) + err = d.safeSearch.Update(r.Context(), conf) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err) diff --git a/internal/home/clients.go b/internal/home/clients.go index 5a30d6bc..5005cbf3 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -3,6 +3,7 @@ package home import ( "context" "fmt" + "log/slog" "net/netip" "slices" "sync" @@ -13,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/whois" @@ -24,6 +26,10 @@ import ( // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { + // baseLogger is used to create loggers with custom prefixes for safe search + // filter. It must not be nil. + baseLogger *slog.Logger + // storage stores information about persistent clients. storage *client.Storage @@ -61,6 +67,8 @@ type BlockedClientChecker interface { // dhcpServer: optional // Note: this function must be called only once func (clients *clientsContainer) Init( + ctx context.Context, + baseLogger *slog.Logger, objects []*clientObject, dhcpServer client.DHCP, etcHosts *aghnet.HostsContainer, @@ -78,7 +86,7 @@ func (clients *clientsContainer) Init( confClients := make([]*client.Persistent, 0, len(objects)) for i, o := range objects { var p *client.Persistent - p, err = o.toPersistent(clients.safeSearchCacheSize, clients.safeSearchCacheTTL) + p, err = o.toPersistent(ctx, baseLogger, clients.safeSearchCacheSize, clients.safeSearchCacheTTL) if err != nil { return fmt.Errorf("init persistent client at index %d: %w", i, err) } @@ -168,6 +176,8 @@ type clientObject struct { // toPersistent returns an initialized persistent client if there are no errors. func (o *clientObject) toPersistent( + ctx context.Context, + baseLogger *slog.Logger, safeSearchCacheSize uint, safeSearchCacheTTL time.Duration, ) (cli *client.Persistent, err error) { @@ -203,14 +213,19 @@ func (o *clientObject) toPersistent( } if o.SafeSearchConf.Enabled { - err = cli.SetSafeSearch( - o.SafeSearchConf, - safeSearchCacheSize, - safeSearchCacheTTL, - ) + var ss *safesearch.Default + ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: baseLogger, + ServicesConfig: o.SafeSearchConf, + ClientName: cli.Name, + CacheSize: safeSearchCacheSize, + CacheTTL: safeSearchCacheTTL, + }) if err != nil { return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err) } + + cli.SafeSearch = ss } if o.BlockedServices == nil { diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index c23f4b23..13aed3cd 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -7,6 +7,8 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -20,7 +22,18 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { testing: true, } - require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{})) + ctx := testutil.ContextWithTimeout(t, testTimeout) + err := c.Init( + ctx, + slogutil.NewDiscardLogger(), + nil, + client.EmptyDHCP{}, + nil, + nil, + &filtering.Config{}, + ) + + require.NoError(t, err) return c } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 73259d29..c5c4913d 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -1,6 +1,7 @@ package home import ( + "context" "encoding/json" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/whois" ) @@ -181,6 +183,7 @@ func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err // jsonToClient converts JSON object to persistent client object if there are no // errors. func (clients *clientsContainer) jsonToClient( + ctx context.Context, cj clientJSON, prev *client.Persistent, ) (c *client.Persistent, err error) { @@ -207,14 +210,19 @@ func (clients *clientsContainer) jsonToClient( c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices if c.SafeSearchConf.Enabled { - err = c.SetSafeSearch( - c.SafeSearchConf, - clients.safeSearchCacheSize, - clients.safeSearchCacheTTL, - ) + var ss *safesearch.Default + ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: clients.baseLogger, + ServicesConfig: c.SafeSearchConf, + ClientName: c.Name, + CacheSize: clients.safeSearchCacheSize, + CacheTTL: clients.safeSearchCacheTTL, + }) if err != nil { return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err) } + + c.SafeSearch = ss } return c, nil @@ -321,7 +329,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - c, err := clients.jsonToClient(cj, nil) + c, err := clients.jsonToClient(r.Context(), cj, nil) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -391,7 +399,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - c, err := clients.jsonToClient(dj.Data, nil) + c, err := clients.jsonToClient(r.Context(), dj.Data, nil) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index 7c1f3dfa..117b9b6e 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -11,14 +11,19 @@ import ( "net/url" "slices" "testing" + "time" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testTimeout is the common timeout for tests and contexts. +const testTimeout = 1 * time.Second + const ( testClientIP1 = "1.1.1.1" testClientIP2 = "2.2.2.2" @@ -105,7 +110,8 @@ func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*c var got []*client.Persistent for _, cj := range clientList.Clients { var c *client.Persistent - c, err = clients.jsonToClient(*cj, nil) + ctx := testutil.ContextWithTimeout(tb, testTimeout) + c, err = clients.jsonToClient(ctx, *cj, nil) require.NoError(tb, err) got = append(got, c) @@ -128,7 +134,8 @@ func assertPersistentClientsData( for _, cm := range data { for _, cj := range cm { var c *client.Persistent - c, err := clients.jsonToClient(*cj, nil) + ctx := testutil.ContextWithTimeout(tb, testTimeout) + c, err := clients.jsonToClient(ctx, *cj, nil) require.NoError(tb, err) got = append(got, c) diff --git a/internal/home/home.go b/internal/home/home.go index df4e4296..09812a91 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -278,8 +278,8 @@ func setupOpts(opts options) (err error) { } // initContextClients initializes Context clients and related fields. -func initContextClients(logger *slog.Logger) (err error) { - err = setupDNSFilteringConf(config.Filtering) +func initContextClients(ctx context.Context, logger *slog.Logger) (err error) { + err = setupDNSFilteringConf(ctx, logger, config.Filtering) if err != nil { // Don't wrap the error, because it's informative enough as is. return err @@ -306,6 +306,8 @@ func initContextClients(logger *slog.Logger) (err error) { } return Context.clients.Init( + ctx, + logger, config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, @@ -355,7 +357,11 @@ func setupBindOpts(opts options) (err error) { } // setupDNSFilteringConf sets up DNS filtering configuration settings. -func setupDNSFilteringConf(conf *filtering.Config) (err error) { +func setupDNSFilteringConf( + ctx context.Context, + baseLogger *slog.Logger, + conf *filtering.Config, +) (err error) { const ( dnsTimeout = 3 * time.Second @@ -446,12 +452,12 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) { conf.ParentalBlockHost = host } - conf.SafeSearch, err = safesearch.NewDefault( - conf.SafeSearchConf, - "default", - conf.SafeSearchCacheSize, - cacheTime, - ) + conf.SafeSearch, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ + BaseLogger: baseLogger, + ServicesConfig: conf.SafeSearchConf, + CacheSize: conf.SafeSearchCacheSize, + CacheTTL: cacheTime, + }) if err != nil { return fmt.Errorf("initializing safesearch: %w", err) } @@ -584,7 +590,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { // data first, but also to avoid relying on automatic Go init() function. filtering.InitModule() - err = initContextClients(slogLogger) + // TODO(s.chzhen): Use it. + ctx := context.Background() + + err = initContextClients(ctx, slogLogger) fatalOnError(err) err = setupOpts(opts)