mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-02-16 09:59:49 +03:00
Pull request 1927: 6006-use-address-processor
Updates #6006. Squashed commit of the following: commit ac27db95c12858b6ef182a0bd4acebab67a23993 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jul 18 15:47:17 2023 +0300 all: imp code commit 3936288512bfc2d44902ead6ab1bb5711f92b73c Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Jul 17 19:23:46 2023 +0300 all: imp client resolving
This commit is contained in:
parent
dead10e033
commit
7bfad08dde
16 changed files with 443 additions and 397 deletions
|
@ -25,6 +25,8 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Occasional client information lookup failures that could lead to the DNS
|
||||||
|
server getting stuck ([#6006]).
|
||||||
- `bufio.Scanner: token too long` and other errors when trying to add
|
- `bufio.Scanner: token too long` and other errors when trying to add
|
||||||
filtering-rule lists with lines over 1024 bytes long or containing cosmetic
|
filtering-rule lists with lines over 1024 bytes long or containing cosmetic
|
||||||
rules ([#6003]).
|
rules ([#6003]).
|
||||||
|
@ -35,6 +37,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
||||||
the `Dockerfile`.
|
the `Dockerfile`.
|
||||||
|
|
||||||
[#6003]: https://github.com/AdguardTeam/AdGuardHome/issues/6003
|
[#6003]: https://github.com/AdguardTeam/AdGuardHome/issues/6003
|
||||||
|
[#6006]: https://github.com/AdguardTeam/AdGuardHome/issues/6006
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -270,7 +271,13 @@ type ServerConfig struct {
|
||||||
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
||||||
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
||||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||||
OnDNSRequest func(d *proxy.DNSContext)
|
|
||||||
|
// AddrProcConf defines the configuration for the client IP processor.
|
||||||
|
// If nil, [client.EmptyAddrProc] is used.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): The use of [client.EmptyAddrProc] is a crutch for tests.
|
||||||
|
// Remove that.
|
||||||
|
AddrProcConf *client.DefaultAddrProcConfig
|
||||||
|
|
||||||
FilteringConfig
|
FilteringConfig
|
||||||
TLSConfig
|
TLSConfig
|
||||||
|
@ -298,9 +305,6 @@ type ServerConfig struct {
|
||||||
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64.
|
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64.
|
||||||
DNS64Prefixes []netip.Prefix
|
DNS64Prefixes []netip.Prefix
|
||||||
|
|
||||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
|
||||||
ResolveClients bool
|
|
||||||
|
|
||||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||||
// locally-served networks should be resolved via private PTR resolvers.
|
// locally-served networks should be resolved via private PTR resolvers.
|
||||||
UsePrivateRDNS bool
|
UsePrivateRDNS bool
|
||||||
|
|
57
internal/dnsforward/dialcontext.go
Normal file
57
internal/dnsforward/dialcontext.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialContext is a [whois.DialContextFunc] that uses s to resolve hostnames.
|
||||||
|
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||||
|
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
// TODO(a.garipov): Consider making configurable.
|
||||||
|
Timeout: time.Minute * 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.ParseIP(host) != nil {
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := s.Resolve(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("dnsforward: resolving %q: %v", host, addrs)
|
||||||
|
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return nil, fmt.Errorf("no addresses for host %q", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialErrs []error
|
||||||
|
for _, a := range addrs {
|
||||||
|
addr = net.JoinHostPort(a.String(), port)
|
||||||
|
conn, err = dialer.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
dialErrs = append(dialErrs, err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(a.garipov): Use errors.Join in Go 1.20.
|
||||||
|
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
|
@ -99,6 +100,10 @@ type Server struct {
|
||||||
// must be a valid domain name plus dots on each side.
|
// must be a valid domain name plus dots on each side.
|
||||||
localDomainSuffix string
|
localDomainSuffix string
|
||||||
|
|
||||||
|
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||||
|
// WHOIS, etc.
|
||||||
|
addrProc client.AddressProcessor
|
||||||
|
|
||||||
ipset ipsetCtx
|
ipset ipsetCtx
|
||||||
privateNets netutil.SubnetSet
|
privateNets netutil.SubnetSet
|
||||||
localResolvers *proxy.Proxy
|
localResolvers *proxy.Proxy
|
||||||
|
@ -170,6 +175,9 @@ const (
|
||||||
|
|
||||||
// NewServer creates a new instance of the dnsforward.Server
|
// NewServer creates a new instance of the dnsforward.Server
|
||||||
// Note: this function must be called only once
|
// Note: this function must be called only once
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): How many constructors and initializers does this thing have?
|
||||||
|
// Refactor!
|
||||||
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
var localDomainSuffix string
|
var localDomainSuffix string
|
||||||
if p.LocalDomain == "" {
|
if p.LocalDomain == "" {
|
||||||
|
@ -257,14 +265,25 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) {
|
||||||
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
|
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RDNSSettings returns the copy of actual RDNS configuration.
|
// LocalPTRResolvers returns the current local PTR resolver configuration.
|
||||||
func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) {
|
func (s *Server) LocalPTRResolvers() (localPTRResolvers []string) {
|
||||||
s.serverLock.RLock()
|
s.serverLock.RLock()
|
||||||
defer s.serverLock.RUnlock()
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers),
|
return stringutil.CloneSlice(s.conf.LocalPTRResolvers)
|
||||||
s.conf.ResolveClients,
|
}
|
||||||
s.conf.UsePrivateRDNS
|
|
||||||
|
// AddrProcConfig returns the current address processing configuration. Only
|
||||||
|
// fields c.UsePrivateRDNS, c.UseRDNS, and c.UseWHOIS are filled.
|
||||||
|
func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
|
||||||
|
s.serverLock.RLock()
|
||||||
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
|
return &client.DefaultAddrProcConfig{
|
||||||
|
UsePrivateRDNS: s.conf.UsePrivateRDNS,
|
||||||
|
UseRDNS: s.conf.AddrProcConf.UseRDNS,
|
||||||
|
UseWHOIS: s.conf.AddrProcConf.UseWHOIS,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve - get IP addresses by host name from an upstream server.
|
// Resolve - get IP addresses by host name from an upstream server.
|
||||||
|
@ -296,10 +315,6 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||||
s.serverLock.RLock()
|
s.serverLock.RLock()
|
||||||
defer s.serverLock.RUnlock()
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
if !s.conf.ResolveClients {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("reversing ip: %w", err)
|
return "", fmt.Errorf("reversing ip: %w", err)
|
||||||
|
@ -318,14 +333,15 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||||
Qclass: dns.ClassINET,
|
Qclass: dns.ClassINET,
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
ctx := &proxy.DNSContext{
|
|
||||||
|
dctx := &proxy.DNSContext{
|
||||||
Proto: "udp",
|
Proto: "udp",
|
||||||
Req: req,
|
Req: req,
|
||||||
StartTime: time.Now(),
|
StartTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var resolver *proxy.Proxy
|
var resolver *proxy.Proxy
|
||||||
if s.isPrivateIP(ip) {
|
if s.privateNets.Contains(ip.AsSlice()) {
|
||||||
if !s.conf.UsePrivateRDNS {
|
if !s.conf.UsePrivateRDNS {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
@ -336,11 +352,11 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||||
resolver = s.internalProxy
|
resolver = s.internalProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = resolver.Resolve(ctx); err != nil {
|
if err = resolver.Resolve(dctx); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hostFromPTR(ctx.Res)
|
return hostFromPTR(dctx.Res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// hostFromPTR returns domain name from the PTR response or error.
|
// hostFromPTR returns domain name from the PTR response or error.
|
||||||
|
@ -364,27 +380,6 @@ func hostFromPTR(resp *dns.Msg) (host string, err error) {
|
||||||
return "", ErrRDNSNoData
|
return "", ErrRDNSNoData
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPrivateIP returns true if the ip is private.
|
|
||||||
func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) {
|
|
||||||
return s.privateNets.Contains(ip.AsSlice())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShouldResolveClient returns false if ip is a loopback address, or ip is
|
|
||||||
// private and resolving of private addresses is disabled.
|
|
||||||
func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) {
|
|
||||||
if ip.IsLoopback() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
isPrivate := s.isPrivateIP(ip)
|
|
||||||
|
|
||||||
s.serverLock.RLock()
|
|
||||||
defer s.serverLock.RUnlock()
|
|
||||||
|
|
||||||
return s.conf.ResolveClients &&
|
|
||||||
(s.conf.UsePrivateRDNS || !isPrivate)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the DNS server.
|
// Start starts the DNS server.
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
s.serverLock.Lock()
|
s.serverLock.Lock()
|
||||||
|
@ -555,6 +550,24 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||||
|
|
||||||
s.recDetector.clear()
|
s.recDetector.clear()
|
||||||
|
|
||||||
|
if s.conf.AddrProcConf == nil {
|
||||||
|
// TODO(a.garipov): This is a crutch for tests; remove.
|
||||||
|
s.conf.AddrProcConf = &client.DefaultAddrProcConfig{}
|
||||||
|
s.addrProc = client.EmptyAddrProc{}
|
||||||
|
} else {
|
||||||
|
c := s.conf.AddrProcConf
|
||||||
|
c.DialContext = s.DialContext
|
||||||
|
c.PrivateSubnets = s.privateNets
|
||||||
|
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
||||||
|
s.addrProc = client.NewDefaultAddrProc(s.conf.AddrProcConf)
|
||||||
|
|
||||||
|
// Clear the initial addresses to not resolve them again.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Consider ways of removing this once more client
|
||||||
|
// logic is moved to package client.
|
||||||
|
c.InitialAddresses = nil
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -696,6 +709,11 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||||
if conf == nil {
|
if conf == nil {
|
||||||
conf = &s.conf
|
conf = &s.conf
|
||||||
|
} else {
|
||||||
|
closeErr := s.addrProc.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
log.Error("dnsforward: closing address processor: %s", closeErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.Prepare(conf)
|
err = s.Prepare(conf)
|
||||||
|
|
|
@ -39,11 +39,29 @@ func TestMain(m *testing.M) {
|
||||||
testutil.DiscardLogOutput(m)
|
testutil.DiscardLogOutput(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testTimeout is the common timeout for tests.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Use more.
|
||||||
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
// testQuestionTarget is the common question target for tests.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Use more.
|
||||||
|
const testQuestionTarget = "target.example"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tlsServerName = "testdns.adguard.com"
|
tlsServerName = "testdns.adguard.com"
|
||||||
testMessagesCount = 10
|
testMessagesCount = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testClientAddr is the common net.Addr for tests.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Use more.
|
||||||
|
var testClientAddr net.Addr = &net.TCPAddr{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
Port: 12345,
|
||||||
|
}
|
||||||
|
|
||||||
func startDeferStop(t *testing.T, s *Server) {
|
func startDeferStop(t *testing.T, s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -53,6 +71,13 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// packageUpstreamVariableMu is used to serialize access to the package-level
|
||||||
|
// variables of package upstream.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Move these parameters to upstream options and remove this
|
||||||
|
// crutch.
|
||||||
|
var packageUpstreamVariableMu = &sync.Mutex{}
|
||||||
|
|
||||||
func createTestServer(
|
func createTestServer(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
filterConf *filtering.Config,
|
filterConf *filtering.Config,
|
||||||
|
@ -61,6 +86,9 @@ func createTestServer(
|
||||||
) (s *Server) {
|
) (s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
packageUpstreamVariableMu.Lock()
|
||||||
|
defer packageUpstreamVariableMu.Unlock()
|
||||||
|
|
||||||
rules := `||nxdomain.example.org
|
rules := `||nxdomain.example.org
|
||||||
||NULL.example.org^
|
||NULL.example.org^
|
||||||
127.0.0.1 host.example.org
|
127.0.0.1 host.example.org
|
||||||
|
@ -307,11 +335,9 @@ func TestServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_timeout(t *testing.T) {
|
func TestServer_timeout(t *testing.T) {
|
||||||
const timeout time.Duration = time.Second
|
|
||||||
|
|
||||||
t.Run("custom", func(t *testing.T) {
|
t.Run("custom", func(t *testing.T) {
|
||||||
srvConf := &ServerConfig{
|
srvConf := &ServerConfig{
|
||||||
UpstreamTimeout: timeout,
|
UpstreamTimeout: testTimeout,
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
BlockingMode: BlockingModeDefault,
|
BlockingMode: BlockingModeDefault,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
@ -324,7 +350,7 @@ func TestServer_timeout(t *testing.T) {
|
||||||
err = s.Prepare(srvConf)
|
err = s.Prepare(srvConf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, timeout, s.conf.UpstreamTimeout)
|
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("default", func(t *testing.T) {
|
t.Run("default", func(t *testing.T) {
|
||||||
|
@ -545,7 +571,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
|
|
||||||
// Send a DNS request without question.
|
// Send a DNS request without question.
|
||||||
_, _, err := (&dns.Client{
|
_, _, err := (&dns.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
Timeout: testTimeout,
|
||||||
}).Exchange(&req, addr)
|
}).Exchange(&req, addr)
|
||||||
|
|
||||||
assert.NoErrorf(t, err, "got a response to an invalid query")
|
assert.NoErrorf(t, err, "got a response to an invalid query")
|
||||||
|
@ -1320,9 +1346,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.conf.ResolveClients = true
|
|
||||||
srv.conf.UsePrivateRDNS = true
|
srv.conf.UsePrivateRDNS = true
|
||||||
|
|
||||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
@ -1396,57 +1420,3 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
assert.Empty(t, host)
|
assert.Empty(t, host)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_ShouldResolveClient(t *testing.T) {
|
|
||||||
srv := &Server{
|
|
||||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
ip netip.Addr
|
|
||||||
want require.BoolAssertionFunc
|
|
||||||
name string
|
|
||||||
resolve bool
|
|
||||||
usePrivate bool
|
|
||||||
}{{
|
|
||||||
name: "default",
|
|
||||||
ip: netip.MustParseAddr("1.1.1.1"),
|
|
||||||
want: require.True,
|
|
||||||
resolve: true,
|
|
||||||
usePrivate: true,
|
|
||||||
}, {
|
|
||||||
name: "no_rdns",
|
|
||||||
ip: netip.MustParseAddr("1.1.1.1"),
|
|
||||||
want: require.False,
|
|
||||||
resolve: false,
|
|
||||||
usePrivate: true,
|
|
||||||
}, {
|
|
||||||
name: "loopback",
|
|
||||||
ip: netip.MustParseAddr("127.0.0.1"),
|
|
||||||
want: require.False,
|
|
||||||
resolve: true,
|
|
||||||
usePrivate: true,
|
|
||||||
}, {
|
|
||||||
name: "private_resolve",
|
|
||||||
ip: netip.MustParseAddr("192.168.0.1"),
|
|
||||||
want: require.True,
|
|
||||||
resolve: true,
|
|
||||||
usePrivate: true,
|
|
||||||
}, {
|
|
||||||
name: "private_no_resolve",
|
|
||||||
ip: netip.MustParseAddr("192.168.0.1"),
|
|
||||||
want: require.False,
|
|
||||||
resolve: true,
|
|
||||||
usePrivate: false,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
srv.conf.ResolveClients = tc.resolve
|
|
||||||
srv.conf.UsePrivateRDNS = tc.usePrivate
|
|
||||||
|
|
||||||
ok := srv.ShouldResolveClient(tc.ip)
|
|
||||||
tc.want(t, ok)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler(
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||||
// the client's IP address and ID, if any, from dctx.
|
// client's IP address and ID, if any, from dctx.
|
||||||
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
|
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||||
setts := s.dnsFilter.Settings()
|
setts = s.dnsFilter.Settings()
|
||||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||||
if s.conf.FilterHandler != nil {
|
if s.conf.FilterHandler != nil {
|
||||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||||
|
|
|
@ -124,7 +124,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||||
cacheMinTTL := s.conf.CacheMinTTL
|
cacheMinTTL := s.conf.CacheMinTTL
|
||||||
cacheMaxTTL := s.conf.CacheMaxTTL
|
cacheMaxTTL := s.conf.CacheMaxTTL
|
||||||
cacheOptimistic := s.conf.CacheOptimistic
|
cacheOptimistic := s.conf.CacheOptimistic
|
||||||
resolveClients := s.conf.ResolveClients
|
resolveClients := s.conf.AddrProcConf.UseRDNS
|
||||||
usePrivateRDNS := s.conf.UsePrivateRDNS
|
usePrivateRDNS := s.conf.UsePrivateRDNS
|
||||||
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
|
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
|
||||||
|
|
||||||
|
@ -314,8 +314,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||||
setIfNotNil(&s.conf.ResolveClients, dc.ResolveClients)
|
|
||||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS)
|
|
||||||
|
|
||||||
return s.setConfigRestartable(dc)
|
return s.setConfigRestartable(dc)
|
||||||
}
|
}
|
||||||
|
@ -335,6 +333,9 @@ func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {
|
||||||
// setConfigRestartable sets the parameters which trigger a restart.
|
// setConfigRestartable sets the parameters which trigger a restart.
|
||||||
// shouldRestart is true if the server should be restarted to apply changes.
|
// shouldRestart is true if the server should be restarted to apply changes.
|
||||||
// s.serverLock is expected to be locked.
|
// s.serverLock is expected to be locked.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Some of these could probably be updated without a restart.
|
||||||
|
// Inspect and consider refactoring.
|
||||||
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||||
for _, hasSet := range []bool{
|
for _, hasSet := range []bool{
|
||||||
setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams),
|
setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams),
|
||||||
|
@ -347,6 +348,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||||
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
||||||
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
||||||
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
||||||
|
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
|
||||||
|
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
|
||||||
} {
|
} {
|
||||||
shouldRestart = shouldRestart || hasSet
|
shouldRestart = shouldRestart || hasSet
|
||||||
if shouldRestart {
|
if shouldRestart {
|
||||||
|
|
|
@ -30,6 +30,7 @@ type dnsContext struct {
|
||||||
setts *filtering.Settings
|
setts *filtering.Settings
|
||||||
|
|
||||||
result *filtering.Result
|
result *filtering.Result
|
||||||
|
|
||||||
// origResp is the response received from upstream. It is set when the
|
// origResp is the response received from upstream. It is set when the
|
||||||
// response is modified by filters.
|
// response is modified by filters.
|
||||||
origResp *dns.Msg
|
origResp *dns.Msg
|
||||||
|
@ -48,13 +49,13 @@ type dnsContext struct {
|
||||||
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
|
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
|
||||||
clientID string
|
clientID string
|
||||||
|
|
||||||
|
// startTime is the time at which the processing of the request has started.
|
||||||
|
startTime time.Time
|
||||||
|
|
||||||
// origQuestion is the question received from the client. It is set
|
// origQuestion is the question received from the client. It is set
|
||||||
// when the request is modified by rewrites.
|
// when the request is modified by rewrites.
|
||||||
origQuestion dns.Question
|
origQuestion dns.Question
|
||||||
|
|
||||||
// startTime is the time at which the processing of the request has started.
|
|
||||||
startTime time.Time
|
|
||||||
|
|
||||||
// protectionEnabled shows if the filtering is enabled, and if the
|
// protectionEnabled shows if the filtering is enabled, and if the
|
||||||
// server's DNS filter is ready.
|
// server's DNS filter is ready.
|
||||||
protectionEnabled bool
|
protectionEnabled bool
|
||||||
|
@ -160,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||||
|
// own DoH server.
|
||||||
|
//
|
||||||
|
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
||||||
|
const mozillaFQDN = "use-application-dns.net."
|
||||||
|
|
||||||
|
// healthcheckFQDN is a reserved domain-name used for healthchecking.
|
||||||
|
//
|
||||||
|
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||||
|
// grant requests to register test names in the normal way to any person or
|
||||||
|
// entity, making domain names under the .test TLD free to use in internal
|
||||||
|
// purposes.
|
||||||
|
//
|
||||||
|
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||||
|
const healthcheckFQDN = "healthcheck.adguardhome.test."
|
||||||
|
|
||||||
// processInitial terminates the following processing for some requests if
|
// processInitial terminates the following processing for some requests if
|
||||||
// needed and enriches dctx with some client-specific information.
|
// needed and enriches dctx with some client-specific information.
|
||||||
//
|
//
|
||||||
|
@ -169,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||||
defer log.Debug("dnsforward: finished processing initial")
|
defer log.Debug("dnsforward: finished processing initial")
|
||||||
|
|
||||||
pctx := dctx.proxyCtx
|
pctx := dctx.proxyCtx
|
||||||
|
s.processClientIP(pctx.Addr)
|
||||||
|
|
||||||
q := pctx.Req.Question[0]
|
q := pctx.Req.Question[0]
|
||||||
qt := q.Qtype
|
qt := q.Qtype
|
||||||
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
|
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
|
||||||
|
@ -177,28 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||||
return resultCodeFinish
|
return resultCodeFinish
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.conf.OnDNSRequest != nil {
|
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
|
||||||
s.conf.OnDNSRequest(pctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disable Mozilla DoH.
|
|
||||||
//
|
|
||||||
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
|
||||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." {
|
|
||||||
pctx.Res = s.genNXDomain(pctx.Req)
|
pctx.Res = s.genNXDomain(pctx.Req)
|
||||||
|
|
||||||
return resultCodeFinish
|
return resultCodeFinish
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle a reserved domain healthcheck.adguardhome.test.
|
if q.Name == healthcheckFQDN {
|
||||||
//
|
|
||||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
|
||||||
// grant requests to register test names in the normal way to any person or
|
|
||||||
// entity, making domain names under test. TLD free to use in internal
|
|
||||||
// purposes.
|
|
||||||
//
|
|
||||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
|
||||||
if q.Name == "healthcheck.adguardhome.test." {
|
|
||||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||||
pctx.Res = s.makeResponse(pctx.Req)
|
pctx.Res = s.makeResponse(pctx.Req)
|
||||||
|
|
||||||
|
@ -213,11 +217,28 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||||
|
|
||||||
// Get the client-specific filtering settings.
|
// Get the client-specific filtering settings.
|
||||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
||||||
dctx.setts = s.getClientRequestFilteringSettings(dctx)
|
dctx.setts = s.clientRequestFilteringSettings(dctx)
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||||
|
func (s *Server) processClientIP(addr net.Addr) {
|
||||||
|
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
|
||||||
|
if clientIP == (netip.Addr{}) {
|
||||||
|
log.Info("dnsforward: warning: bad client addr %q", addr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not assign s.addrProc to a local variable to then use, since this lock
|
||||||
|
// also serializes the closure of s.addrProc.
|
||||||
|
s.serverLock.RLock()
|
||||||
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
|
s.addrProc.Process(clientIP)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) setTableHostToIP(t hostToIPTable) {
|
func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||||
s.tableHostToIPLock.Lock()
|
s.tableHostToIPLock.Lock()
|
||||||
defer s.tableHostToIPLock.Unlock()
|
defer s.tableHostToIPLock.Unlock()
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -22,6 +23,96 @@ const (
|
||||||
ddrTestFQDN = ddrTestDomainName + "."
|
ddrTestFQDN = ddrTestDomainName + "."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestServer_ProcessInitial(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
target string
|
||||||
|
wantRCode rules.RCode
|
||||||
|
qType rules.RRType
|
||||||
|
aaaaDisabled bool
|
||||||
|
wantRC resultCode
|
||||||
|
}{{
|
||||||
|
name: "success",
|
||||||
|
target: testQuestionTarget,
|
||||||
|
wantRCode: -1,
|
||||||
|
qType: dns.TypeA,
|
||||||
|
aaaaDisabled: false,
|
||||||
|
wantRC: resultCodeSuccess,
|
||||||
|
}, {
|
||||||
|
name: "aaaa_disabled",
|
||||||
|
target: testQuestionTarget,
|
||||||
|
wantRCode: dns.RcodeSuccess,
|
||||||
|
qType: dns.TypeAAAA,
|
||||||
|
aaaaDisabled: true,
|
||||||
|
wantRC: resultCodeFinish,
|
||||||
|
}, {
|
||||||
|
name: "aaaa_disabled_a",
|
||||||
|
target: testQuestionTarget,
|
||||||
|
wantRCode: -1,
|
||||||
|
qType: dns.TypeA,
|
||||||
|
aaaaDisabled: true,
|
||||||
|
wantRC: resultCodeSuccess,
|
||||||
|
}, {
|
||||||
|
name: "mozilla_canary",
|
||||||
|
target: mozillaFQDN,
|
||||||
|
wantRCode: dns.RcodeNameError,
|
||||||
|
qType: dns.TypeA,
|
||||||
|
aaaaDisabled: false,
|
||||||
|
wantRC: resultCodeFinish,
|
||||||
|
}, {
|
||||||
|
name: "adguardhome_healthcheck",
|
||||||
|
target: healthcheckFQDN,
|
||||||
|
wantRCode: dns.RcodeSuccess,
|
||||||
|
qType: dns.TypeA,
|
||||||
|
aaaaDisabled: false,
|
||||||
|
wantRC: resultCodeFinish,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
c := ServerConfig{
|
||||||
|
FilteringConfig: FilteringConfig{
|
||||||
|
AAAADisabled: tc.aaaaDisabled,
|
||||||
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := createTestServer(t, &filtering.Config{}, c, nil)
|
||||||
|
|
||||||
|
var gotAddr netip.Addr
|
||||||
|
s.addrProc = &aghtest.AddressProcessor{
|
||||||
|
OnProcess: func(ip netip.Addr) { gotAddr = ip },
|
||||||
|
OnClose: func() (err error) { panic("not implemented") },
|
||||||
|
}
|
||||||
|
|
||||||
|
dctx := &dnsContext{
|
||||||
|
proxyCtx: &proxy.DNSContext{
|
||||||
|
Req: createTestMessageWithType(tc.target, tc.qType),
|
||||||
|
Addr: testClientAddr,
|
||||||
|
RequestID: 1234,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gotRC := s.processInitial(dctx)
|
||||||
|
assert.Equal(t, tc.wantRC, gotRC)
|
||||||
|
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
|
||||||
|
|
||||||
|
if tc.wantRCode > 0 {
|
||||||
|
gotResp := dctx.proxyCtx.Res
|
||||||
|
require.NotNil(t, gotResp)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantRCode, gotResp.Rcode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
dohSVCB := &dns.SVCB{
|
dohSVCB := &dns.SVCB{
|
||||||
Priority: 1,
|
Priority: 1,
|
||||||
|
@ -64,7 +155,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
}{{
|
}{{
|
||||||
name: "pass_host",
|
name: "pass_host",
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
host: "example.net.",
|
host: testQuestionTarget,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoH: 8043,
|
portDoH: 8043,
|
||||||
|
@ -234,33 +325,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
wantIP netip.Addr
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
wantIP netip.Addr
|
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
isLocalCli bool
|
isLocalCli bool
|
||||||
}{{
|
}{{
|
||||||
|
wantIP: knownIP,
|
||||||
name: "local_client_success",
|
name: "local_client_success",
|
||||||
host: "example.lan",
|
host: "example.lan",
|
||||||
wantIP: knownIP,
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
isLocalCli: true,
|
isLocalCli: true,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "local_client_unknown_host",
|
name: "local_client_unknown_host",
|
||||||
host: "wronghost.lan",
|
host: "wronghost.lan",
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
isLocalCli: true,
|
isLocalCli: true,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "external_client_known_host",
|
name: "external_client_known_host",
|
||||||
host: "example.lan",
|
host: "example.lan",
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
isLocalCli: false,
|
isLocalCli: false,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "external_client_unknown_host",
|
name: "external_client_unknown_host",
|
||||||
host: "wronghost.lan",
|
host: "wronghost.lan",
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
isLocalCli: false,
|
isLocalCli: false,
|
||||||
}}
|
}}
|
||||||
|
@ -332,52 +423,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
|
|
||||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
wantIP netip.Addr
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
suffix string
|
suffix string
|
||||||
wantIP netip.Addr
|
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
qtyp uint16
|
qtyp uint16
|
||||||
}{{
|
}{{
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "success_external",
|
name: "success_external",
|
||||||
host: examplecom,
|
host: examplecom,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "success_external_non_a",
|
name: "success_external_non_a",
|
||||||
host: examplecom,
|
host: examplecom,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeCNAME,
|
qtyp: dns.TypeCNAME,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: knownIP,
|
||||||
name: "success_internal",
|
name: "success_internal",
|
||||||
host: examplelan,
|
host: examplelan,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: knownIP,
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "success_internal_unknown",
|
name: "success_internal_unknown",
|
||||||
host: "example-new.lan",
|
host: "example-new.lan",
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: netip.Addr{},
|
||||||
name: "success_internal_aaaa",
|
name: "success_internal_aaaa",
|
||||||
host: examplelan,
|
host: examplelan,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeAAAA,
|
qtyp: dns.TypeAAAA,
|
||||||
}, {
|
}, {
|
||||||
|
wantIP: knownIP,
|
||||||
name: "success_custom_suffix",
|
name: "success_custom_suffix",
|
||||||
host: "example.custom",
|
host: "example.custom",
|
||||||
suffix: "custom",
|
suffix: "custom",
|
||||||
wantIP: knownIP,
|
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}}
|
}}
|
||||||
|
@ -560,10 +651,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
var dnsCtx *dnsContext
|
var dnsCtx *dnsContext
|
||||||
setup := func(use bool) {
|
setup := func(use bool) {
|
||||||
proxyCtx = &proxy.DNSContext{
|
proxyCtx = &proxy.DNSContext{
|
||||||
Addr: &net.TCPAddr{
|
Addr: testClientAddr,
|
||||||
IP: net.IP{127, 0, 0, 1},
|
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||||
},
|
|
||||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
|
||||||
}
|
}
|
||||||
dnsCtx = &dnsContext{
|
dnsCtx = &dnsContext{
|
||||||
proxyCtx: proxyCtx,
|
proxyCtx: proxyCtx,
|
|
@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||||
|
|
||||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||||
// mechanism of loading TLS roots does not always work properly on some
|
// loading TLS roots does not always work properly on some routers so we're
|
||||||
// routers so we're loading roots manually and pass it here.
|
// loading roots manually and pass it here.
|
||||||
//
|
//
|
||||||
// See [aghtls.SystemRootCAs].
|
// See [aghtls.SystemRootCAs].
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Investigate if that's true.
|
||||||
upstream.RootCAs = s.conf.TLSv12Roots
|
upstream.RootCAs = s.conf.TLSv12Roots
|
||||||
upstream.CipherSuites = s.conf.TLSCiphers
|
upstream.CipherSuites = s.conf.TLSCiphers
|
||||||
|
|
||||||
|
@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts(
|
||||||
|
|
||||||
// extractUpstreamHost returns the hostname of addr without port with an
|
// extractUpstreamHost returns the hostname of addr without port with an
|
||||||
// assumption that any address passed here has already been successfully parsed
|
// assumption that any address passed here has already been successfully parsed
|
||||||
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
|
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
|
||||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||||
func extractUpstreamHost(addr string) (host string) {
|
func extractUpstreamHost(addr string) (host string) {
|
||||||
var err error
|
var err error
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
|
@ -141,7 +142,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// webHandlersRegistered prevents a [clientsContainer] from regisering its web
|
// webHandlersRegistered prevents a [clientsContainer] from registering its web
|
||||||
// handlers more than once.
|
// handlers more than once.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Refactor HTTP handler registration logic.
|
// TODO(a.garipov): Refactor HTTP handler registration logic.
|
||||||
|
@ -743,11 +744,9 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setWHOISInfo sets the WHOIS information for a client.
|
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
|
||||||
|
// expected to be locked.
|
||||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
_, ok := clients.findLocked(ip.String())
|
_, ok := clients.findLocked(ip.String())
|
||||||
if ok {
|
if ok {
|
||||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||||
|
@ -774,9 +773,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||||
rc.WHOIS = wi
|
rc.WHOIS = wi
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||||
// taken into account. ok is true if the pairing was added.
|
// taken into account. ok is true if the pairing was added.
|
||||||
func (clients *clientsContainer) AddHost(
|
//
|
||||||
|
// TODO(a.garipov): Only used in internal tests. Consider removing.
|
||||||
|
func (clients *clientsContainer) addHost(
|
||||||
ip netip.Addr,
|
ip netip.Addr,
|
||||||
host string,
|
host string,
|
||||||
src clientSource,
|
src clientSource,
|
||||||
|
@ -787,6 +788,32 @@ func (clients *clientsContainer) AddHost(
|
||||||
return clients.addHostLocked(ip, host, src)
|
return clients.addHostLocked(ip, host, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||||
|
|
||||||
|
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||||
|
// *clientsContainer
|
||||||
|
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||||
|
// Common fast path optimization.
|
||||||
|
if host == "" && info == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clients.lock.Lock()
|
||||||
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
if host != "" {
|
||||||
|
ok := clients.addHostLocked(ip, host, ClientSourceRDNS)
|
||||||
|
if !ok {
|
||||||
|
log.Debug("clients: host for client %q already set with higher priority source", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if info != nil {
|
||||||
|
clients.setWHOISInfo(ip, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
||||||
// locked.
|
// locked.
|
||||||
func (clients *clientsContainer) addHostLocked(
|
func (clients *clientsContainer) addHostLocked(
|
||||||
|
|
|
@ -168,13 +168,13 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
t.Run("addhost_success", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
ok := clients.AddHost(ip, "host", ClientSourceARP)
|
ok := clients.addHost(ip, "host", ClientSourceARP)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok = clients.AddHost(ip, "host2", ClientSourceARP)
|
ok = clients.addHost(ip, "host2", ClientSourceARP)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
ok = clients.addHost(ip, "host3", ClientSourceHostsFile)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
|
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
|
||||||
|
@ -182,18 +182,18 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.2.3.4")
|
ip := netip.MustParseAddr("1.2.3.4")
|
||||||
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
ok := clients.addHost(ip, "from_arp", ClientSourceARP)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
|
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
|
||||||
|
|
||||||
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
ok = clients.addHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
|
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_fail", func(t *testing.T) {
|
t.Run("addhost_fail", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
|
ok := clients.addHost(ip, "host1", ClientSourceRDNS)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -216,7 +216,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("existing_auto-client", func(t *testing.T) {
|
t.Run("existing_auto-client", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
|
ok := clients.addHost(ip, "host", ClientSourceRDNS)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
|
@ -259,7 +259,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Now add an auto-client with the same IP.
|
// Now add an auto-client with the same IP.
|
||||||
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
|
ok = clients.addHost(ip, "test", ClientSourceRDNS)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
|
@ -590,7 +590,13 @@ func (c *configuration) write() (err error) {
|
||||||
s.WriteDiskConfig(&c)
|
s.WriteDiskConfig(&c)
|
||||||
dns := &config.DNS
|
dns := &config.DNS
|
||||||
dns.FilteringConfig = c
|
dns.FilteringConfig = c
|
||||||
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
|
|
||||||
|
dns.LocalPTRResolvers = s.LocalPTRResolvers()
|
||||||
|
|
||||||
|
addrProcConf := s.AddrProcConfig()
|
||||||
|
config.Clients.Sources.RDNS = addrProcConf.UseRDNS
|
||||||
|
config.Clients.Sources.WHOIS = addrProcConf.UseWHOIS
|
||||||
|
dns.UsePrivateRDNS = addrProcConf.UsePrivateRDNS
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.dhcpServer != nil {
|
if Context.dhcpServer != nil {
|
||||||
|
|
|
@ -13,14 +13,12 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
@ -135,7 +133,7 @@ func initDNSServer(
|
||||||
return fmt.Errorf("preparing set of private subnets: %w", err)
|
return fmt.Errorf("preparing set of private subnets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p := dnsforward.DNSCreateParams{
|
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||||
DNSFilter: filters,
|
DNSFilter: filters,
|
||||||
Stats: sts,
|
Stats: sts,
|
||||||
QueryLog: qlog,
|
QueryLog: qlog,
|
||||||
|
@ -143,9 +141,7 @@ func initDNSServer(
|
||||||
Anonymizer: anonymizer,
|
Anonymizer: anonymizer,
|
||||||
LocalDomain: config.DHCP.LocalDomainName,
|
LocalDomain: config.DHCP.LocalDomainName,
|
||||||
DHCPServer: dhcpSrv,
|
DHCPServer: dhcpSrv,
|
||||||
}
|
})
|
||||||
|
|
||||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeDNSServer()
|
closeDNSServer()
|
||||||
|
|
||||||
|
@ -154,134 +150,23 @@ func initDNSServer(
|
||||||
|
|
||||||
Context.clients.dnsServer = Context.dnsServer
|
Context.clients.dnsServer = Context.dnsServer
|
||||||
|
|
||||||
dnsConf, err := generateServerConfig(tlsConf, httpReg)
|
dnsConf, err := newServerConfig(tlsConf, httpReg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeDNSServer()
|
closeDNSServer()
|
||||||
|
|
||||||
return fmt.Errorf("generateServerConfig: %w", err)
|
return fmt.Errorf("newServerConfig: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.dnsServer.Prepare(&dnsConf)
|
err = Context.dnsServer.Prepare(dnsConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeDNSServer()
|
closeDNSServer()
|
||||||
|
|
||||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
initRDNS()
|
|
||||||
initWHOIS()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
|
|
||||||
// processing.
|
|
||||||
defaultQueueSize = 255
|
|
||||||
|
|
||||||
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
|
|
||||||
// processing. It must be greater than zero.
|
|
||||||
defaultCacheSize = 10_000
|
|
||||||
|
|
||||||
// defaultIPTTL is the Time to Live duration for IP addresses cached by
|
|
||||||
// rDNS and WHOIS.
|
|
||||||
defaultIPTTL = 1 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
// initRDNS initializes the rDNS.
|
|
||||||
func initRDNS() {
|
|
||||||
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
|
|
||||||
|
|
||||||
// TODO(s.chzhen): Add ability to disable it on dns server configuration
|
|
||||||
// update in [dnsforward] package.
|
|
||||||
r := rdns.New(&rdns.Config{
|
|
||||||
Exchanger: Context.dnsServer,
|
|
||||||
CacheSize: defaultCacheSize,
|
|
||||||
CacheTTL: defaultIPTTL,
|
|
||||||
})
|
|
||||||
|
|
||||||
go processRDNS(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// processRDNS processes reverse DNS lookup queries. It is intended to be used
|
|
||||||
// as a goroutine.
|
|
||||||
func processRDNS(r rdns.Interface) {
|
|
||||||
defer log.OnPanic("rdns")
|
|
||||||
|
|
||||||
for ip := range Context.rdnsCh {
|
|
||||||
ok := Context.dnsServer.ShouldResolveClient(ip)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
host, changed := r.Process(ip)
|
|
||||||
if host == "" || !changed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
|
|
||||||
if ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug(
|
|
||||||
"dns: can't set rdns info for client %q: already set with higher priority source",
|
|
||||||
ip,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initWHOIS initializes the WHOIS.
|
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Consider making configurable.
|
|
||||||
func initWHOIS() {
|
|
||||||
const (
|
|
||||||
// defaultTimeout is the timeout for WHOIS requests.
|
|
||||||
defaultTimeout = 5 * time.Second
|
|
||||||
|
|
||||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from
|
|
||||||
// net.Conn.
|
|
||||||
defaultMaxConnReadSize = 64 * 1024
|
|
||||||
|
|
||||||
// defaultMaxRedirects is the maximum redirects count.
|
|
||||||
defaultMaxRedirects = 5
|
|
||||||
|
|
||||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
|
||||||
defaultMaxInfoLen = 250
|
|
||||||
)
|
|
||||||
|
|
||||||
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
|
|
||||||
|
|
||||||
var w whois.Interface
|
|
||||||
|
|
||||||
if config.Clients.Sources.WHOIS {
|
|
||||||
w = whois.New(&whois.Config{
|
|
||||||
DialContext: customDialContext,
|
|
||||||
ServerAddr: whois.DefaultServer,
|
|
||||||
Port: whois.DefaultPort,
|
|
||||||
Timeout: defaultTimeout,
|
|
||||||
CacheSize: defaultCacheSize,
|
|
||||||
MaxConnReadSize: defaultMaxConnReadSize,
|
|
||||||
MaxRedirects: defaultMaxRedirects,
|
|
||||||
MaxInfoLen: defaultMaxInfoLen,
|
|
||||||
CacheTTL: defaultIPTTL,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
w = whois.Empty{}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer log.OnPanic("whois")
|
|
||||||
|
|
||||||
for ip := range Context.whoisCh {
|
|
||||||
info, changed := w.Process(context.Background(), ip)
|
|
||||||
if info != nil && changed {
|
|
||||||
Context.clients.setWHOISInfo(ip, info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||||
// a subnet set that matches all locally served networks, see
|
// a subnet set that matches all locally served networks, see
|
||||||
// [netutil.IsLocallyServed].
|
// [netutil.IsLocallyServed].
|
||||||
|
@ -312,17 +197,6 @@ func isRunning() bool {
|
||||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
|
||||||
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
|
|
||||||
if ip == (netip.Addr{}) {
|
|
||||||
// This would be quite weird if we get here.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.rdnsCh <- ip
|
|
||||||
Context.whoisCh <- ip
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||||
if ips == nil {
|
if ips == nil {
|
||||||
return nil
|
return nil
|
||||||
|
@ -349,23 +223,35 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||||
return udpAddrs
|
return udpAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateServerConfig(
|
func newServerConfig(
|
||||||
tlsConf *tlsConfigSettings,
|
tlsConf *tlsConfigSettings,
|
||||||
httpReg aghhttp.RegisterFunc,
|
httpReg aghhttp.RegisterFunc,
|
||||||
) (newConf dnsforward.ServerConfig, err error) {
|
) (newConf *dnsforward.ServerConfig, err error) {
|
||||||
dnsConf := config.DNS
|
dnsConf := config.DNS
|
||||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||||
newConf = dnsforward.ServerConfig{
|
|
||||||
|
newConf = &dnsforward.ServerConfig{
|
||||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||||
FilteringConfig: dnsConf.FilteringConfig,
|
FilteringConfig: dnsConf.FilteringConfig,
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpReg,
|
HTTPRegister: httpReg,
|
||||||
OnDNSRequest: onDNSRequest,
|
|
||||||
UseDNS64: config.DNS.UseDNS64,
|
UseDNS64: config.DNS.UseDNS64,
|
||||||
DNS64Prefixes: config.DNS.DNS64Prefixes,
|
DNS64Prefixes: config.DNS.DNS64Prefixes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const initialClientsNum = 100
|
||||||
|
|
||||||
|
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
|
||||||
|
// are set by [dnsforward.Server.Prepare].
|
||||||
|
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
|
||||||
|
Exchanger: Context.dnsServer,
|
||||||
|
AddressUpdater: &Context.clients,
|
||||||
|
InitialAddresses: Context.stats.TopClientsIP(initialClientsNum),
|
||||||
|
UseRDNS: config.Clients.Sources.RDNS,
|
||||||
|
UseWHOIS: config.Clients.Sources.WHOIS,
|
||||||
|
}
|
||||||
|
|
||||||
if tlsConf.Enabled {
|
if tlsConf.Enabled {
|
||||||
newConf.TLSConfig = tlsConf.TLSConfig
|
newConf.TLSConfig = tlsConf.TLSConfig
|
||||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||||
|
@ -385,9 +271,9 @@ func generateServerConfig(
|
||||||
if tlsConf.PortDNSCrypt != 0 {
|
if tlsConf.PortDNSCrypt != 0 {
|
||||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
|
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's already
|
// Don't wrap the error, because it's already wrapped by
|
||||||
// wrapped by newDNSCrypt.
|
// newDNSCrypt.
|
||||||
return dnsforward.ServerConfig{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -401,7 +287,6 @@ func generateServerConfig(
|
||||||
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
|
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
|
||||||
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
|
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
|
||||||
|
|
||||||
newConf.ResolveClients = config.Clients.Sources.RDNS
|
|
||||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||||
newConf.ServeHTTP3 = dnsConf.ServeHTTP3
|
newConf.ServeHTTP3 = dnsConf.ServeHTTP3
|
||||||
newConf.UseHTTP3Upstreams = dnsConf.UseHTTP3Upstreams
|
newConf.UseHTTP3Upstreams = dnsConf.UseHTTP3Upstreams
|
||||||
|
@ -556,27 +441,19 @@ func startDNSServer() error {
|
||||||
Context.stats.Start()
|
Context.stats.Start()
|
||||||
Context.queryLog.Start()
|
Context.queryLog.Start()
|
||||||
|
|
||||||
const topClientsNumber = 100 // the number of clients to get
|
|
||||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
|
||||||
Context.rdnsCh <- ip
|
|
||||||
Context.whoisCh <- ip
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reconfigureDNSServer() (err error) {
|
func reconfigureDNSServer() (err error) {
|
||||||
var newConf dnsforward.ServerConfig
|
|
||||||
|
|
||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfigSettings{}
|
||||||
Context.tls.WriteDiskConfig(tlsConf)
|
Context.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
newConf, err = generateServerConfig(tlsConf, httpRegister)
|
newConf, err := newServerConfig(tlsConf, httpRegister)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.dnsServer.Reconfigure(&newConf)
|
err = Context.dnsServer.Reconfigure(newConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting forwarding dns server: %w", err)
|
return fmt.Errorf("starting forwarding dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,14 +3,12 @@ package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -79,15 +77,8 @@ type homeContext struct {
|
||||||
pidFileName string // PID file name. Empty if no PID file was created.
|
pidFileName string // PID file name. Empty if no PID file was created.
|
||||||
controlLock sync.Mutex
|
controlLock sync.Mutex
|
||||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||||
client *http.Client
|
|
||||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||||
|
|
||||||
// rdnsCh is the channel for receiving IPs for rDNS processing.
|
|
||||||
rdnsCh chan netip.Addr
|
|
||||||
|
|
||||||
// whoisCh is the channel for receiving IPs for WHOIS processing.
|
|
||||||
whoisCh chan netip.Addr
|
|
||||||
|
|
||||||
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||||
tlsCipherIDs []uint16
|
tlsCipherIDs []uint16
|
||||||
|
|
||||||
|
@ -156,19 +147,6 @@ func setupContext(opts options) (err error) {
|
||||||
setupContextFlags(opts)
|
setupContextFlags(opts)
|
||||||
|
|
||||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
Context.tlsRoots = aghtls.SystemRootCAs()
|
||||||
Context.client = &http.Client{
|
|
||||||
Timeout: time.Minute * 5,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
DialContext: customDialContext,
|
|
||||||
Proxy: getHTTPProxy,
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
RootCAs: Context.tlsRoots,
|
|
||||||
CipherSuites: Context.tlsCipherIDs,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.mux = http.NewServeMux()
|
Context.mux = http.NewServeMux()
|
||||||
|
|
||||||
if !Context.firstRun {
|
if !Context.firstRun {
|
||||||
|
@ -341,7 +319,7 @@ func initContextClients() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.updater = updater.NewUpdater(&updater.Config{
|
Context.updater = updater.NewUpdater(&updater.Config{
|
||||||
Client: Context.client,
|
Client: config.DNS.DnsfilterConf.HTTPClient,
|
||||||
Version: version.Version(),
|
Version: version.Version(),
|
||||||
Channel: version.Channel(),
|
Channel: version.Channel(),
|
||||||
GOARCH: runtime.GOARCH,
|
GOARCH: runtime.GOARCH,
|
||||||
|
@ -433,7 +411,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||||
conf.Filters = slices.Clone(config.Filters)
|
conf.Filters = slices.Clone(config.Filters)
|
||||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||||
conf.UserRules = slices.Clone(config.UserRules)
|
conf.UserRules = slices.Clone(config.UserRules)
|
||||||
conf.HTTPClient = Context.client
|
conf.HTTPClient = httpClient()
|
||||||
|
|
||||||
cacheTime := time.Duration(conf.CacheTime) * time.Minute
|
cacheTime := time.Duration(conf.CacheTime) * time.Minute
|
||||||
|
|
||||||
|
@ -634,10 +612,10 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||||
Context.tls.start()
|
Context.tls.start()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
sErr := startDNSServer()
|
startErr := startDNSServer()
|
||||||
if sErr != nil {
|
if startErr != nil {
|
||||||
closeDNSServer()
|
closeDNSServer()
|
||||||
fatalOnError(sErr)
|
fatalOnError(startErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -996,62 +974,6 @@ func detectFirstRun() bool {
|
||||||
return errors.Is(err, os.ErrNotExist)
|
return errors.Is(err, os.ErrNotExist)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to a remote server resolving hostname using our own DNS server.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): This messy logic should be decomposed and clarified.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Support network.
|
|
||||||
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
|
||||||
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
|
|
||||||
|
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: time.Minute * 5,
|
|
||||||
}
|
|
||||||
|
|
||||||
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
|
||||||
return dialer.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs, err := Context.dnsServer.Resolve(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("dnsServer.Resolve: %q: %v", host, addrs)
|
|
||||||
|
|
||||||
if len(addrs) == 0 {
|
|
||||||
return nil, fmt.Errorf("couldn't lookup host: %q", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
var dialErrs []error
|
|
||||||
for _, a := range addrs {
|
|
||||||
addr = net.JoinHostPort(a.String(), port)
|
|
||||||
conn, err = dialer.DialContext(ctx, network, addr)
|
|
||||||
if err != nil {
|
|
||||||
dialErrs = append(dialErrs, err)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
|
||||||
if config.ProxyURL == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return url.Parse(config.ProxyURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// jsonError is a generic JSON error response.
|
// jsonError is a generic JSON error response.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
|
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
|
||||||
|
|
47
internal/home/httpclient.go
Normal file
47
internal/home/httpclient.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpClient returns a new HTTP client that uses the AdGuard Home's own DNS
|
||||||
|
// server for resolving hostnames. The resulting client should not be used
|
||||||
|
// until [Context.dnsServer] is initialized.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov, e.burkov): This is rather messy. Refactor.
|
||||||
|
func httpClient() (c *http.Client) {
|
||||||
|
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
||||||
|
// below, since Context.dnsServer may be nil when this function is called.
|
||||||
|
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||||
|
return Context.dnsServer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
// TODO(a.garipov): Make configurable.
|
||||||
|
Timeout: time.Minute * 5,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: dialContext,
|
||||||
|
Proxy: httpProxy,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: Context.tlsRoots,
|
||||||
|
CipherSuites: Context.tlsCipherIDs,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpProxy returns parses and returns an HTTP proxy URL from the config, if
|
||||||
|
// any.
|
||||||
|
func httpProxy(_ *http.Request) (u *url.URL, err error) {
|
||||||
|
if config.ProxyURL == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.Parse(config.ProxyURL)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue