From 03c69ab2729eb424d768def986cba83731ad3e3b Mon Sep 17 00:00:00 2001
From: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Wed, 21 Aug 2024 19:08:30 +0300
Subject: [PATCH] all: imp code

---
 internal/dnsforward/config.go               |  4 +-
 internal/dnsforward/dnsforward.go           | 10 ++--
 internal/dnsforward/ipset.go                | 51 +++++++++++----------
 internal/dnsforward/ipset_internal_test.go  | 11 +++--
 internal/ipset/ipset.go                     | 34 +++++++++-----
 internal/ipset/ipset_linux.go               | 34 +++++++-------
 internal/ipset/ipset_linux_internal_test.go | 17 +++++--
 internal/ipset/ipset_others.go              |  4 +-
 8 files changed, 94 insertions(+), 71 deletions(-)

diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go
index c5c45dee..1e735660 100644
--- a/internal/dnsforward/config.go
+++ b/internal/dnsforward/config.go
@@ -159,7 +159,7 @@ type Config struct {
 	// IpsetList is the ipset configuration that allows AdGuard Home to add IP
 	// addresses of the specified domain names to an ipset list.  Syntax:
 	//
-	//	DOMAIN[,DOMAIN].../IPSET_NAME
+	//	DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
 	//
 	// This field is ignored if [IpsetListFileName] is set.
 	IpsetList []string `yaml:"ipset"`
@@ -470,7 +470,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
 	}
 
 	ipsets = stringutil.SplitTrimmed(string(data), "\n")
-	ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty)
+	slices.DeleteFunc(ipsets, IsCommentOrEmpty)
 
 	log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
 
diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go
index c92f2221..997925f2 100644
--- a/internal/dnsforward/dnsforward.go
+++ b/internal/dnsforward/dnsforward.go
@@ -133,8 +133,9 @@ type Server struct {
 	// must be a valid domain name plus dots on each side.
 	localDomainSuffix string
 
-	// ipset processes DNS requests using ipset data.
-	ipset *ipsetCtx
+	// ipset processes DNS requests using ipset data.  It must not be nil after
+	// initialization.  See [newIpsetHandler].
+	ipset *ipsetHandler
 
 	// privateNets is the configured set of IP networks considered private.
 	privateNets netutil.SubnetSet
@@ -609,12 +610,13 @@ func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
 // the primary DNS proxy instance.  It assumes s.serverLock is locked or the
 // Server not running.
 func (s *Server) prepareInternalDNS() (err error) {
-	ipsetConf, err := s.prepareIpsetListSettings()
+	ipsetList, err := s.prepareIpsetListSettings()
 	if err != nil {
 		return fmt.Errorf("preparing ipset settings: %w", err)
 	}
 
-	s.ipset, err = newIPSetCtx(s.logger, ipsetConf)
+	ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset")
+	s.ipset, err = newIpsetHandler(ipsetLogger, ipsetList)
 	if err != nil {
 		// Don't wrap the error, because it's informative enough as is.
 		return err
diff --git a/internal/dnsforward/ipset.go b/internal/dnsforward/ipset.go
index 64adcfe4..f3c414a5 100644
--- a/internal/dnsforward/ipset.go
+++ b/internal/dnsforward/ipset.go
@@ -1,6 +1,7 @@
 package dnsforward
 
 import (
+	"context"
 	"fmt"
 	"log/slog"
 	"net"
@@ -13,21 +14,25 @@ import (
 	"github.com/miekg/dns"
 )
 
-// ipsetCtx is the ipset context.  ipsetMgr can be nil.
-type ipsetCtx struct {
+// ipsetHandler is the ipset context.  ipsetMgr can be nil.
+type ipsetHandler struct {
 	ipsetMgr ipset.Manager
 	logger   *slog.Logger
 }
 
-// newIPSetCtx returns a new initialized [ipsetCtx].  It is not safe for
-// concurrent use.
-func newIPSetCtx(logger *slog.Logger, ipsetConf []string) (c *ipsetCtx, err error) {
-	c = &ipsetCtx{
+// newIpsetHandler returns a new initialized [ipsetHandler].  It is not safe for
+// concurrent use.  c is always non-nil for [Server.Close].
+func newIpsetHandler(logger *slog.Logger, ipsetList []string) (c *ipsetHandler, err error) {
+	c = &ipsetHandler{
 		logger: logger,
 	}
-	ipsetLogger := logger.With(slogutil.KeyPrefix, "ipset")
-	c.ipsetMgr, err = ipset.NewManager(ipsetLogger, ipsetConf)
-	if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
+	c.ipsetMgr, err = ipset.NewManager(&ipset.Config{
+		Logger:    logger,
+		IpsetList: ipsetList,
+	})
+	if errors.Is(err, os.ErrInvalid) ||
+		errors.Is(err, os.ErrPermission) ||
+		errors.Is(err, errors.ErrUnsupported) {
 		// ipset cannot currently be initialized if the server was installed
 		// from Snap or when the user or the binary doesn't have the required
 		// permissions, or when the kernel doesn't support netfilter.
@@ -36,22 +41,18 @@ func newIPSetCtx(logger *slog.Logger, ipsetConf []string) (c *ipsetCtx, err erro
 		//
 		// TODO(a.garipov): The Snap problem can probably be solved if we add
 		// the netlink-connector interface plug.
-		logger.Warn("ipset: cannot initialize", slogutil.KeyError, err)
-
-		return c, nil
-	} else if errors.Is(err, errors.ErrUnsupported) {
-		logger.Warn("ipset: cannot initialize", slogutil.KeyError, err)
+		logger.Warn("cannot initialize", slogutil.KeyError, err)
 
 		return c, nil
 	} else if err != nil {
-		return nil, fmt.Errorf("initializing ipset: %w", err)
+		return c, fmt.Errorf("initializing ipset: %w", err)
 	}
 
 	return c, nil
 }
 
 // close closes the Linux Netfilter connections.
-func (c *ipsetCtx) close() (err error) {
+func (c *ipsetHandler) close() (err error) {
 	if c.ipsetMgr != nil {
 		return c.ipsetMgr.Close()
 	}
@@ -59,7 +60,7 @@ func (c *ipsetCtx) close() (err error) {
 	return nil
 }
 
-func (c *ipsetCtx) dctxIsfilled(dctx *dnsContext) (ok bool) {
+func (c *ipsetHandler) dctxIsfilled(dctx *dnsContext) (ok bool) {
 	return dctx != nil &&
 		dctx.responseFromUpstream &&
 		dctx.proxyCtx != nil &&
@@ -70,7 +71,7 @@ func (c *ipsetCtx) dctxIsfilled(dctx *dnsContext) (ok bool) {
 
 // skipIpsetProcessing returns true when the ipset processing can be skipped for
 // this request.
-func (c *ipsetCtx) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
+func (c *ipsetHandler) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
 	if c == nil || c.ipsetMgr == nil || !c.dctxIsfilled(dctx) {
 		return true
 	}
@@ -113,9 +114,11 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
 }
 
 // process adds the resolved IP addresses to the domain's ipsets, if any.
-func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
-	c.logger.Debug("ipset: started processing")
-	defer c.logger.Debug("ipset: finished processing")
+func (c *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
+	c.logger.Debug("started processing")
+	defer c.logger.Debug("finished processing")
+
+	ctx := context.TODO()
 
 	if c.skipIpsetProcessing(dctx) {
 		return resultCodeSuccess
@@ -127,15 +130,15 @@ func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
 	host = strings.ToLower(host)
 
 	ip4s, ip6s := ipsFromAnswer(dctx.proxyCtx.Res.Answer)
-	n, err := c.ipsetMgr.Add(host, ip4s, ip6s)
+	n, err := c.ipsetMgr.Add(ctx, host, ip4s, ip6s)
 	if err != nil {
 		// Consider ipset errors non-critical to the request.
-		c.logger.Error("ipset: adding host ips", slogutil.KeyError, err)
+		c.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
 
 		return resultCodeSuccess
 	}
 
-	c.logger.Debug("ipset: added new ipset entries", "num", n)
+	c.logger.DebugContext(ctx, "added new ipset entries", "num", n)
 
 	return resultCodeSuccess
 }
diff --git a/internal/dnsforward/ipset_internal_test.go b/internal/dnsforward/ipset_internal_test.go
index 01e29832..09601ac6 100644
--- a/internal/dnsforward/ipset_internal_test.go
+++ b/internal/dnsforward/ipset_internal_test.go
@@ -1,6 +1,7 @@
 package dnsforward
 
 import (
+	"context"
 	"net"
 	"testing"
 
@@ -17,7 +18,7 @@ type fakeIpsetMgr struct {
 }
 
 // Add implements the aghnet.IpsetManager interface for *fakeIpsetMgr.
-func (m *fakeIpsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
+func (m *fakeIpsetMgr) Add(_ context.Context, host string, ip4s, ip6s []net.IP) (n int, err error) {
 	m.ip4s = append(m.ip4s, ip4s...)
 	m.ip6s = append(m.ip6s, ip6s...)
 
@@ -59,7 +60,7 @@ func TestIpsetCtx_process(t *testing.T) {
 			responseFromUpstream: true,
 		}
 
-		ictx := &ipsetCtx{
+		ictx := &ipsetHandler{
 			logger: slogutil.NewDiscardLogger(),
 		}
 		rc := ictx.process(dctx)
@@ -80,7 +81,7 @@ func TestIpsetCtx_process(t *testing.T) {
 		}
 
 		m := &fakeIpsetMgr{}
-		ictx := &ipsetCtx{
+		ictx := &ipsetHandler{
 			ipsetMgr: m,
 			logger:   slogutil.NewDiscardLogger(),
 		}
@@ -105,7 +106,7 @@ func TestIpsetCtx_process(t *testing.T) {
 		}
 
 		m := &fakeIpsetMgr{}
-		ictx := &ipsetCtx{
+		ictx := &ipsetHandler{
 			ipsetMgr: m,
 			logger:   slogutil.NewDiscardLogger(),
 		}
@@ -129,7 +130,7 @@ func TestIpsetCtx_SkipIpsetProcessing(t *testing.T) {
 	}
 
 	m := &fakeIpsetMgr{}
-	ictx := &ipsetCtx{
+	ictx := &ipsetHandler{
 		ipsetMgr: m,
 		logger:   slogutil.NewDiscardLogger(),
 	}
diff --git a/internal/ipset/ipset.go b/internal/ipset/ipset.go
index f271c67e..5adfc3a6 100644
--- a/internal/ipset/ipset.go
+++ b/internal/ipset/ipset.go
@@ -2,6 +2,7 @@
 package ipset
 
 import (
+	"context"
 	"log/slog"
 	"net"
 )
@@ -11,24 +12,33 @@ import (
 // TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type,
 // since ipset is exclusive to Linux?
 type Manager interface {
-	Add(host string, ip4s, ip6s []net.IP) (n int, err error)
+	Add(ctx context.Context, host string, ip4s, ip6s []net.IP) (n int, err error)
 	Close() (err error)
 }
 
-// NewManager returns a new ipset manager.  IPv4 addresses are added to an
-// ipset with an ipv4 family; IPv6 addresses, to an ipv6 ipset.  ipset must
-// exist.
+// Config is the configuration structure for the ipset manager.
+type Config struct {
+	// Logger is used for logging the operation of the ipset manager.  It must
+	// not be nil.
+	Logger *slog.Logger
+
+	// IpsetList is the ipset configuration with the following syntax:
+	//
+	//	DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
+	//
+	// IpsetList must not contain any blank lines or comments.
+	IpsetList []string
+}
+
+// NewManager returns a new ipset manager.  IPv4 addresses are added to an ipset
+// with an ipv4 family; IPv6 addresses, to an ipv6 ipset.  ipset must exist.
 //
-// The syntax of the ipsetConf is:
-//
-//	DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
-//
-// If ipsetConf is empty, msg and err are nil.  The error's chain contains
+// If conf.IpsetList is empty, mgr and err are nil.  The error's chain contains
 // [errors.ErrUnsupported] if current OS is not supported.
-func NewManager(logger *slog.Logger, ipsetConf []string) (mgr Manager, err error) {
-	if len(ipsetConf) == 0 {
+func NewManager(conf *Config) (mgr Manager, err error) {
+	if len(conf.IpsetList) == 0 {
 		return nil, nil
 	}
 
-	return newManager(logger, ipsetConf)
+	return newManager(conf)
 }
diff --git a/internal/ipset/ipset_linux.go b/internal/ipset/ipset_linux.go
index 7e842e53..a8e992b8 100644
--- a/internal/ipset/ipset_linux.go
+++ b/internal/ipset/ipset_linux.go
@@ -4,6 +4,7 @@ package ipset
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"log/slog"
 	"net"
@@ -35,8 +36,8 @@ import (
 //     resolved IP addresses.
 
 // newManager returns a new Linux ipset manager.
-func newManager(logger *slog.Logger, ipsetConf []string) (set Manager, err error) {
-	return newManagerWithDialer(logger, ipsetConf, defaultDial)
+func newManager(conf *Config) (set Manager, err error) {
+	return newManagerWithDialer(conf, defaultDial)
 }
 
 // defaultDial is the default netfilter dialing function.
@@ -339,8 +340,8 @@ func (m *manager) ipsets(names []string, currentlyKnown map[string]props) (sets
 		}
 
 		if p.family != netfilter.ProtoIPv4 && p.family != netfilter.ProtoIPv6 {
-			m.logger.Debug("getting properties",
-				slogutil.KeyError, "unexpected ipset family",
+			m.logger.Debug(
+				"got unexpected ipset family while getting set properties",
 				"set_name", p.name,
 				"set_type", p.typeName,
 				"set_family", p.family,
@@ -361,11 +362,7 @@ func (m *manager) ipsets(names []string, currentlyKnown map[string]props) (sets
 
 // newManagerWithDialer returns a new Linux ipset manager using the provided
 // dialer.
-func newManagerWithDialer(
-	logger *slog.Logger,
-	ipsetConf []string,
-	dial dialer,
-) (mgr Manager, err error) {
+func newManagerWithDialer(conf *Config, dial dialer) (mgr Manager, err error) {
 	defer func() { err = errors.Annotate(err, "ipset: %w") }()
 
 	m := &manager{
@@ -374,7 +371,7 @@ func newManagerWithDialer(
 		nameToIpset:    make(map[string]props),
 		domainToIpsets: make(map[string][]props),
 
-		logger: logger,
+		logger: conf.Logger,
 
 		dial: dial,
 
@@ -386,7 +383,7 @@ func newManagerWithDialer(
 		if errors.Is(err, unix.EPROTONOSUPPORT) {
 			// The implementation doesn't support this protocol version.  Just
 			// issue a warning.
-			logger.Warn("dialing netfilter", slogutil.KeyError, err)
+			m.logger.Warn("dialing netfilter", slogutil.KeyError, err)
 
 			return nil, nil
 		}
@@ -394,12 +391,12 @@ func newManagerWithDialer(
 		return nil, fmt.Errorf("dialing netfilter: %w", err)
 	}
 
-	err = m.parseIpsetConfig(ipsetConf)
+	err = m.parseIpsetConfig(conf.IpsetList)
 	if err != nil {
 		return nil, fmt.Errorf("getting ipsets: %w", err)
 	}
 
-	logger.Debug("initialized")
+	m.logger.Debug("initialized")
 
 	return m, nil
 }
@@ -486,6 +483,7 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
 
 // addToSets adds the IP addresses to the corresponding ipset.
 func (m *manager) addToSets(
+	ctx context.Context,
 	host string,
 	ip4s []net.IP,
 	ip6s []net.IP,
@@ -508,7 +506,9 @@ func (m *manager) addToSets(
 			return n, fmt.Errorf("%q %q unexpected family %q", set.name, set.typeName, set.family)
 		}
 
-		m.logger.Debug("added ips to set",
+		m.logger.DebugContext(
+			ctx,
+			"added ips to set",
 			"ips_num", nn,
 			"set_name", set.name,
 			"set_type", set.typeName,
@@ -521,7 +521,7 @@ func (m *manager) addToSets(
 }
 
 // Add implements the [Manager] interface for *manager.
-func (m *manager) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
+func (m *manager) Add(ctx context.Context, host string, ip4s, ip6s []net.IP) (n int, err error) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
@@ -530,9 +530,9 @@ func (m *manager) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
 		return 0, nil
 	}
 
-	m.logger.Debug("found sets", "set_num", len(sets))
+	m.logger.DebugContext(ctx, "found sets", "set_num", len(sets))
 
-	return m.addToSets(host, ip4s, ip6s, sets)
+	return m.addToSets(ctx, host, ip4s, ip6s, sets)
 }
 
 // Close implements the [Manager] interface for *manager.
diff --git a/internal/ipset/ipset_linux_internal_test.go b/internal/ipset/ipset_linux_internal_test.go
index 31e9e823..f3805e85 100644
--- a/internal/ipset/ipset_linux_internal_test.go
+++ b/internal/ipset/ipset_linux_internal_test.go
@@ -6,9 +6,11 @@ import (
 	"net"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/AdguardTeam/golibs/errors"
 	"github.com/AdguardTeam/golibs/logutil/slogutil"
+	"github.com/AdguardTeam/golibs/testutil"
 	"github.com/digineo/go-ipset/v2"
 	"github.com/mdlayher/netlink"
 	"github.com/stretchr/testify/assert"
@@ -16,6 +18,9 @@ import (
 	"github.com/ti-mo/netfilter"
 )
 
+// testTimeout is a common timeout for tests and contexts.
+const testTimeout = 1 * time.Second
+
 // fakeConn is a fake ipsetConn for tests.
 type fakeConn struct {
 	ipv4Header  *ipset.HeaderPolicy
@@ -59,7 +64,7 @@ func (c *fakeConn) listAll() (sets []props, err error) {
 }
 
 func TestManager_Add(t *testing.T) {
-	ipsetConf := []string{
+	ipsetList := []string{
 		"example.com,example.net/ipv4set",
 		"example.org,example.biz/ipv6set",
 	}
@@ -90,7 +95,11 @@ func TestManager_Add(t *testing.T) {
 		}, nil
 	}
 
-	m, err := newManagerWithDialer(slogutil.NewDiscardLogger(), ipsetConf, fakeDial)
+	conf := &Config{
+		Logger:    slogutil.NewDiscardLogger(),
+		IpsetList: ipsetList,
+	}
+	m, err := newManagerWithDialer(conf, fakeDial)
 	require.NoError(t, err)
 
 	ip4 := net.IP{1, 2, 3, 4}
@@ -101,7 +110,7 @@ func TestManager_Add(t *testing.T) {
 		0x00, 0x00, 0x56, 0x78,
 	}
 
-	n, err := m.Add("example.net", []net.IP{ip4}, nil)
+	n, err := m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.net", []net.IP{ip4}, nil)
 	require.NoError(t, err)
 
 	assert.Equal(t, 1, n)
@@ -111,7 +120,7 @@ func TestManager_Add(t *testing.T) {
 	gotIP4 := ipv4Entries[0].IP.Value
 	assert.Equal(t, ip4, gotIP4)
 
-	n, err = m.Add("example.biz", nil, []net.IP{ip6})
+	n, err = m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.biz", nil, []net.IP{ip6})
 	require.NoError(t, err)
 
 	assert.Equal(t, 1, n)
diff --git a/internal/ipset/ipset_others.go b/internal/ipset/ipset_others.go
index 4090b0b3..577fd319 100644
--- a/internal/ipset/ipset_others.go
+++ b/internal/ipset/ipset_others.go
@@ -3,11 +3,9 @@
 package ipset
 
 import (
-	"log/slog"
-
 	"github.com/AdguardTeam/AdGuardHome/internal/aghos"
 )
 
-func newManager(_ *slog.Logger, _ []string) (mgr Manager, err error) {
+func newManager(_ *Config) (mgr Manager, err error) {
 	return nil, aghos.Unsupported("ipset")
 }