all: imp tests

This commit is contained in:
Stanislav Chzhen 2025-02-25 17:16:31 +03:00
parent 00f0eb6047
commit bdb08ee0e6
17 changed files with 205 additions and 93 deletions

View file

@ -4,8 +4,6 @@ import "github.com/AdguardTeam/dnsproxy/upstream"
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration // UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration. // depending on configuration.
//
// TODO(s.chzhen): !! Use in the dnsforward package.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 { if !http3 {
return upstream.DefaultHTTPVersions return upstream.DefaultHTTPVersions

View file

@ -0,0 +1,26 @@
package aghnet_test
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/stretchr/testify/assert"
)
func TestIsCommentOrEmpty(t *testing.T) {
for _, tc := range []struct {
want assert.BoolAssertionFunc
str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, aghnet.IsCommentOrEmpty(tc.str))
}
}

View file

@ -1282,3 +1282,72 @@ func TestStorage_RangeByName(t *testing.T) {
}) })
} }
} }
func TestStorage_CustomUpstreamConfig(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
nonExistingClientID = "non_existing_client_id"
)
var (
existingClientUID = client.MustNewUID()
existingIP = netip.MustParseAddr("192.0.2.1")
nonExistingIP = netip.MustParseAddr("192.0.2.255")
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
Upstreams: []string{"192.0.2.0"},
}
s := newTestStorage(t)
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{})
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := s.Add(ctx, existingClient)
require.NoError(t, err)
testCases := []struct {
cliAddr netip.Addr
wantNilConf assert.BoolAssertionFunc
name string
cliID string
}{{
name: "client_id",
cliID: existingClientID,
cliAddr: netip.Addr{},
wantNilConf: assert.False,
}, {
name: "client_addr",
cliID: "",
cliAddr: existingIP,
wantNilConf: assert.False,
}, {
name: "non_existing_client_id",
cliID: nonExistingClientID,
cliAddr: netip.Addr{},
wantNilConf: assert.True,
}, {
name: "non_existing_client_addr",
cliID: "",
cliAddr: nonExistingIP,
wantNilConf: assert.True,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := s.CustomUpstreamConfig(tc.cliID, tc.cliAddr)
tc.wantNilConf(t, conf == nil)
})
}
}

View file

@ -1,6 +1,7 @@
package client package client
import ( import (
"slices"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@ -23,8 +24,11 @@ type CommonUpstreamConfig struct {
// customUpstreamConfig contains custom client upstream configuration and the // customUpstreamConfig contains custom client upstream configuration and the
// timestamp of the latest configuration update. // timestamp of the latest configuration update.
type customUpstreamConfig struct { type customUpstreamConfig struct {
prxConf *proxy.CustomUpstreamConfig prxConf *proxy.CustomUpstreamConfig
confUpdate time.Time confUpdate time.Time
upstreams []string
upstreamsCacheSize uint32
upstreamsCacheEnabled bool
} }
// upstreamManager stores and updates custom client upstream configurations. // upstreamManager stores and updates custom client upstream configurations.
@ -60,19 +64,40 @@ func (m *upstreamManager) customUpstreamConfig(
c *Persistent, c *Persistent,
) (prxConf *proxy.CustomUpstreamConfig) { ) (prxConf *proxy.CustomUpstreamConfig) {
cliConf, ok := m.uidToCustomConf[c.UID] cliConf, ok := m.uidToCustomConf[c.UID]
if ok && m.confUpdate.Equal(cliConf.confUpdate) { if ok && !m.isConfigChanged(c, cliConf) {
return cliConf.prxConf return cliConf.prxConf
} }
prxConf = newCustomUpstreamConfig(c, m.commonConf) prxConf = newCustomUpstreamConfig(c, m.commonConf)
m.uidToCustomConf[c.UID] = &customUpstreamConfig{ m.uidToCustomConf[c.UID] = &customUpstreamConfig{
prxConf: prxConf, prxConf: prxConf,
confUpdate: m.confUpdate, confUpdate: m.confUpdate,
upstreams: slices.Clone(c.Upstreams),
upstreamsCacheEnabled: c.UpstreamsCacheEnabled,
upstreamsCacheSize: c.UpstreamsCacheSize,
} }
return prxConf return prxConf
} }
// isConfigChanged returns true if the update is necessary for the custom client
// upstream configuration.
func (m *upstreamManager) isConfigChanged(c *Persistent, cliConf *customUpstreamConfig) (ok bool) {
if !slices.Equal(c.Upstreams, cliConf.upstreams) {
return true
}
if c.UpstreamsCacheEnabled != cliConf.upstreamsCacheEnabled {
return true
}
if c.UpstreamsCacheSize != cliConf.upstreamsCacheSize {
return true
}
return !m.confUpdate.Equal(cliConf.confUpdate)
}
// clearUpstreamCache clears the upstream cache for each stored custom client // clearUpstreamCache clears the upstream cache for each stored custom client
// upstream configuration. // upstream configuration.
func (m *upstreamManager) clearUpstreamCache() { func (m *upstreamManager) clearUpstreamCache() {
@ -97,6 +122,10 @@ func (m *upstreamManager) remove(c *Persistent) (err error) {
func (m *upstreamManager) close() (err error) { func (m *upstreamManager) close() (err error) {
var errs []error var errs []error
for _, c := range m.uidToCustomConf { for _, c := range m.uidToCustomConf {
if c.prxConf == nil {
continue
}
errs = append(errs, c.prxConf.Close()) errs = append(errs, c.prxConf.Close())
} }

View file

@ -266,6 +266,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {
UpstreamDNS: []string{localUpsAddr}, UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })

View file

@ -44,6 +44,29 @@ type ClientsContainer interface {
ClearUpstreamCache() ClearUpstreamCache()
} }
// EmptyClientsContainer is an [ClientsContainer] implementation that does nothing.
type EmptyClientsContainer struct{}
// type check
var _ ClientsContainer = EmptyClientsContainer{}
// CustomUpstreamConfig implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) CustomUpstreamConfig(
clientID string,
cliAddr netip.Addr,
) (conf *proxy.CustomUpstreamConfig) {
return nil
}
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {}
// ClearUpstreamCache implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) ClearUpstreamCache() {}
// Config represents the DNS filtering configuration of AdGuard Home. The zero // Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use. // Config is empty and ready for use.
type Config struct { type Config struct {
@ -469,7 +492,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
} }
ipsets = stringutil.SplitTrimmed(string(data), "\n") ipsets = stringutil.SplitTrimmed(string(data), "\n")
ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty) ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn) log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
@ -480,7 +503,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
// the configuration itself. // the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) { func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" { if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil return stringutil.FilterOut(conf.UpstreamDNS, aghnet.IsCommentOrEmpty), nil
} }
var data []byte var data []byte
@ -493,7 +516,7 @@ func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName) log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
} }
// collectListenAddr adds addrPort to addrs. It also adds its port to // collectListenAddr adds addrPort to addrs. It also adds its port to

View file

@ -299,6 +299,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
UpstreamDNS: []string{upsAddr}, UpstreamDNS: []string{upsAddr},
}, },
UsePrivateRDNS: true, UsePrivateRDNS: true,
@ -337,6 +338,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
UpstreamDNS: []string{upsAddr}, UpstreamDNS: []string{upsAddr},
}, },
UsePrivateRDNS: false, UsePrivateRDNS: false,

View file

@ -540,7 +540,7 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot, Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout, Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6, PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of // Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're // loading TLS roots does not always work properly on some routers so we're
@ -557,17 +557,13 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
} }
s.conf.UpstreamConfig = uc s.conf.UpstreamConfig = uc
s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
// TODO(s.chzhen): !! Fix tests. Bootstrap: boot,
if s.conf.ClientsContainer != nil { UpstreamTimeout: s.conf.UpstreamTimeout,
s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{ BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
Bootstrap: boot, EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
UpstreamTimeout: s.conf.UpstreamTimeout, UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6, })
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
})
}
return nil return nil
} }
@ -641,7 +637,7 @@ func (s *Server) prepareInternalDNS() (err error) {
bootOpts := &upstream.Options{ bootOpts := &upstream.Options{
Timeout: DefaultTimeout, Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
} }
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts) s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
@ -672,7 +668,7 @@ func (s *Server) prepareInternalDNS() (err error) {
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) { func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
fallbacks := s.conf.FallbackDNS fallbacks := s.conf.FallbackDNS
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty) fallbacks = stringutil.FilterOut(fallbacks, aghnet.IsCommentOrEmpty)
if len(fallbacks) == 0 { if len(fallbacks) == 0 {
return nil, nil return nil, nil
} }

View file

@ -205,6 +205,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })
@ -334,6 +335,7 @@ func TestServer(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })
@ -374,6 +376,7 @@ func TestServer_timeout(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -401,6 +404,7 @@ func TestServer_timeout(t *testing.T) {
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{ s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false, Enabled: false,
} }
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
err = s.Prepare(&s.conf) err = s.Prepare(&s.conf)
require.NoError(t, err) require.NoError(t, err)
@ -417,6 +421,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
}, },
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -442,6 +447,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })
@ -573,6 +579,7 @@ func TestSafeSearch(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -666,6 +673,7 @@ func TestInvalidRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })
@ -696,6 +704,7 @@ func TestBlockedRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -733,6 +742,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -811,6 +821,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })
@ -845,6 +856,7 @@ func TestBlockCNAME(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -921,6 +933,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -967,6 +980,7 @@ func TestNullBlockedRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -1035,6 +1049,7 @@ func TestBlockedCustomIP(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -1088,6 +1103,7 @@ func TestBlockedByHosts(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -1140,6 +1156,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -1201,6 +1218,7 @@ func TestRewrite(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
})) }))
@ -1327,6 +1345,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false} s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf) err = s.Prepare(&s.conf)
@ -1412,6 +1431,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false} s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf) err = s.Prepare(&s.conf)
@ -1680,6 +1700,7 @@ func TestServer_Exchange(t *testing.T) {
UpstreamDNS: []string{upsAddr}, UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
LocalPTRResolvers: []string{localUpsAddr}, LocalPTRResolvers: []string{localUpsAddr},
UsePrivateRDNS: true, UsePrivateRDNS: true,
@ -1702,6 +1723,7 @@ func TestServer_Exchange(t *testing.T) {
UpstreamDNS: []string{upsAddr}, UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
LocalPTRResolvers: []string{}, LocalPTRResolvers: []string{},
ServePlainDNS: true, ServePlainDNS: true,

View file

@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })

View file

@ -36,6 +36,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{ EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false, Enabled: false,
}, },
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -647,7 +648,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
return return
} }
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty) req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, aghnet.IsCommentOrEmpty)
opts := &upstream.Options{ opts := &upstream.Options{
Timeout: s.conf.UpstreamTimeout, Timeout: s.conf.UpstreamTimeout,

View file

@ -83,6 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
RatelimitSubnetLenIPv6: 56, RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ConfigModified: func() {}, ConfigModified: func() {},
ServePlainDNS: true, ServePlainDNS: true,
@ -164,6 +165,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
RatelimitSubnetLenIPv6: 56, RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ConfigModified: func() {}, ConfigModified: func() {},
ServePlainDNS: true, ServePlainDNS: true,
@ -299,24 +301,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
} }
} }
func TestIsCommentOrEmpty(t *testing.T) {
for _, tc := range []struct {
want assert.BoolAssertionFunc
str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, IsCommentOrEmpty(tc.str))
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper() t.Helper()
@ -388,6 +372,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })

View file

@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
AAAADisabled: tc.aaaaDisabled, AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
AAAADisabled: tc.aaaaDisabled, AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
@ -324,6 +326,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
HandleDDR: tc.ddrEnabled, HandleDDR: tc.ddrEnabled,
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
TLSConfig: TLSConfig{ TLSConfig: TLSConfig{
ServerName: ddrTestDomainName, ServerName: ddrTestDomainName,
@ -660,6 +663,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
UpstreamDNS: []string{localUpsAddr}, UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
UsePrivateRDNS: true, UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr}, LocalPTRResolvers: []string{localUpsAddr},
@ -788,6 +792,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
UsePrivateRDNS: true, UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr}, LocalPTRResolvers: []string{localUpsAddr},
@ -816,6 +821,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
UsePrivateRDNS: false, UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr}, LocalPTRResolvers: []string{localUpsAddr},

View file

@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
Config: Config{ Config: Config{
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}) })

View file

@ -94,7 +94,7 @@ func newPrivateConfig(
) (uc *proxy.UpstreamConfig, err error) { ) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0 confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering { if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty) addrs = stringutil.FilterOut(addrs, aghnet.IsCommentOrEmpty)
} else { } else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has) sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers)) addrs = make([]string, 0, len(sysResolvers))
@ -127,20 +127,6 @@ func newPrivateConfig(
return uc, nil return uc, nil
} }
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// setProxyUpstreamMode sets the upstream mode and related settings in conf // setProxyUpstreamMode sets the upstream mode and related settings in conf
// based on provided parameters. // based on provided parameters.
func setProxyUpstreamMode( func setProxyUpstreamMode(
@ -162,10 +148,3 @@ func setProxyUpstreamMode(
return nil return nil
} }
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}

View file

@ -1,14 +1,12 @@
package home package home
import ( import (
"net/netip"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -36,29 +34,3 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
return c return c
} }
// TODO(s.chzhen): !! Move to client package.
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add client with upstreams.
err := clients.storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
})
require.NoError(t, err)
clients.storage.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{})
upsConf := clients.storage.CustomUpstreamConfig("", netip.MustParseAddr("1.2.3.4"))
assert.Nil(t, upsConf)
upsConf = clients.storage.CustomUpstreamConfig("", netip.MustParseAddr("1.1.1.1"))
require.NotNil(t, upsConf)
}