Pull request 1878: AG-22597-imp-rdns

Squashed commit of the following:

commit ccad155c34989943d88a0a260c50845d1f4ece6b
Merge: 0cd889f6a 5a195b441
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Jul 6 17:00:58 2023 +0300

    Merge branch 'master' into AG-22597-imp-rdns

commit 0cd889f6a500f5616af0f8d8fdcde0403b87ad4f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Jul 6 12:20:49 2023 +0300

    dnsforward: imp code

commit 1aaa1998b914b0d53142c21fa3bdcae502e4f3f6
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jul 4 20:11:55 2023 +0300

    home: add todo

commit aed232fcf70ef546f373d5235b73abcb4fbb4b6c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jul 4 13:25:28 2023 +0300

    all: imp code, tests

commit 5c028c2766ffb8ebdc358a245a249c6a55d9ad81
Merge: 83d6ae7f6 97af062f7
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 3 18:54:42 2023 +0300

    Merge branch 'master' into AG-22597-imp-rdns

commit 83d6ae7f61a7b81a8d73cd6d747035278c64fb70
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 3 18:53:05 2023 +0300

    home: imp code

commit 8153988dece0406e51a90a43eaffae59dba30a36
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 30 18:06:09 2023 +0300

    all: imp code

commit 00d3cc11a9378318f176aae00ddf972f255d575c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 30 13:05:04 2023 +0300

    all: add tests

commit ffdc95f237bfdb780922b4390d82cdc0154b0621
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Jun 29 15:20:00 2023 +0300

    all: imp code, docs

commit 0dc60e2b355750ca701558927d22fb9ad187ea7e
Merge: 69dd56bdb d4a4bda64
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Jun 29 15:13:19 2023 +0300

    Merge branch 'master' into AG-22597-imp-rdns

commit 69dd56bdb75056b0fa6bcf6538af7fff93383323
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 23 14:36:29 2023 +0300

    rdns: add tests

commit 16909b51adbe3a3f230291834cc9486dd8a0e8f8
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jun 19 16:28:26 2023 +0300

    rdns: extract rdns
This commit is contained in:
Stanislav Chzhen 2023-07-06 17:10:06 +03:00
parent 5a195b441c
commit c21f958eaf
9 changed files with 389 additions and 458 deletions

View file

@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -277,17 +278,6 @@ func (s *Server) Resolve(host string) ([]net.IPAddr, error) {
return s.internalProxy.LookupIPAddr(host) return s.internalProxy.LookupIPAddr(host)
} }
// RDNSExchanger is a resolver for clients' addresses.
type RDNSExchanger interface {
// Exchange tries to resolve the ip in a suitable way, i.e. either as local
// or as external.
Exchange(ip net.IP) (host string, err error)
// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
// resolve PTR requests for locally-served addresses.
ResolvesPrivatePTR() (ok bool)
}
const ( const (
// ErrRDNSNoData is returned by [RDNSExchanger.Exchange] when the answer // ErrRDNSNoData is returned by [RDNSExchanger.Exchange] when the answer
// section of response is either NODATA or has no PTR records. // section of response is either NODATA or has no PTR records.
@ -299,10 +289,10 @@ const (
) )
// type check // type check
var _ RDNSExchanger = (*Server)(nil) var _ rdns.Exchanger = (*Server)(nil)
// Exchange implements the RDNSExchanger interface for *Server. // Exchange implements the [rdns.Exchanger] interface for *Server.
func (s *Server) Exchange(ip net.IP) (host string, err error) { func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
@ -310,7 +300,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
return "", nil return "", nil
} }
arpa, err := netutil.IPToReversedAddr(ip) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", fmt.Errorf("reversing ip: %w", err) return "", fmt.Errorf("reversing ip: %w", err)
} }
@ -335,7 +325,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
} }
var resolver *proxy.Proxy var resolver *proxy.Proxy
if s.privateNets.Contains(ip) { if s.isPrivateIP(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", nil return "", nil
} }
@ -350,8 +340,12 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
return "", err return "", err
} }
return hostFromPTR(ctx.Res)
}
// hostFromPTR returns domain name from the PTR response or error.
func hostFromPTR(resp *dns.Msg) (host string, err error) {
// Distinguish between NODATA response and a failed request. // Distinguish between NODATA response and a failed request.
resp := ctx.Res
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
return "", fmt.Errorf( return "", fmt.Errorf(
"received %s response: %w", "received %s response: %w",
@ -370,12 +364,25 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
return "", ErrRDNSNoData return "", ErrRDNSNoData
} }
// ResolvesPrivatePTR implements the RDNSExchanger interface for *Server. // isPrivateIP returns true if the ip is private.
func (s *Server) ResolvesPrivatePTR() (ok bool) { func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) {
return s.privateNets.Contains(ip.AsSlice())
}
// ShouldResolveClient returns false if ip is a loopback address, or ip is
// private and resolving of private addresses is disabled.
func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) {
if ip.IsLoopback() {
return false
}
isPrivate := s.isPrivateIP(ip)
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
return s.conf.UsePrivateRDNS return s.conf.ResolveClients &&
(s.conf.UsePrivateRDNS || !isPrivate)
} }
// Start starts the DNS server. // Start starts the DNS server.

View file

@ -1273,11 +1273,11 @@ func TestServer_Exchange(t *testing.T) {
) )
var ( var (
onesIP = net.IP{1, 1, 1, 1} onesIP = netip.MustParseAddr("1.1.1.1")
localIP = net.IP{192, 168, 1, 1} localIP = netip.MustParseAddr("192.168.1.1")
) )
revExtIPv4, err := netutil.IPToReversedAddr(onesIP) revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
require.NoError(t, err) require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{ extUpstream := &aghtest.UpstreamMock{
@ -1290,7 +1290,7 @@ func TestServer_Exchange(t *testing.T) {
}, },
} }
revLocIPv4, err := netutil.IPToReversedAddr(localIP) revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err) require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{ locUpstream := &aghtest.UpstreamMock{
@ -1330,7 +1330,7 @@ func TestServer_Exchange(t *testing.T) {
want string want string
wantErr error wantErr error
locUpstream upstream.Upstream locUpstream upstream.Upstream
req net.IP req netip.Addr
}{{ }{{
name: "external_good", name: "external_good",
want: onesHost, want: onesHost,
@ -1354,7 +1354,7 @@ func TestServer_Exchange(t *testing.T) {
want: "", want: "",
wantErr: ErrRDNSNoData, wantErr: ErrRDNSNoData,
locUpstream: locUpstream, locUpstream: locUpstream,
req: net.IP{192, 168, 1, 2}, req: netip.MustParseAddr("192.168.1.2"),
}, { }, {
name: "invalid_answer", name: "invalid_answer",
want: "", want: "",
@ -1396,3 +1396,57 @@ func TestServer_Exchange(t *testing.T) {
assert.Empty(t, host) assert.Empty(t, host)
}) })
} }
func TestServer_ShouldResolveClient(t *testing.T) {
srv := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
ip netip.Addr
want require.BoolAssertionFunc
name string
resolve bool
usePrivate bool
}{{
name: "default",
ip: netip.MustParseAddr("1.1.1.1"),
want: require.True,
resolve: true,
usePrivate: true,
}, {
name: "no_rdns",
ip: netip.MustParseAddr("1.1.1.1"),
want: require.False,
resolve: false,
usePrivate: true,
}, {
name: "loopback",
ip: netip.MustParseAddr("127.0.0.1"),
want: require.False,
resolve: true,
usePrivate: true,
}, {
name: "private_resolve",
ip: netip.MustParseAddr("192.168.0.1"),
want: require.True,
resolve: true,
usePrivate: true,
}, {
name: "private_no_resolve",
ip: netip.MustParseAddr("192.168.0.1"),
want: require.False,
resolve: true,
usePrivate: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
srv.conf.ResolveClients = tc.resolve
srv.conf.UsePrivateRDNS = tc.usePrivate
ok := srv.ShouldResolveClient(tc.ip)
tc.want(t, ok)
})
}
}

View file

@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -167,30 +168,77 @@ func initDNSServer(
return fmt.Errorf("dnsServer.Prepare: %w", err) return fmt.Errorf("dnsServer.Prepare: %w", err)
} }
if config.Clients.Sources.RDNS { initRDNS()
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
initWHOIS() initWHOIS()
return nil return nil
} }
const (
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
// processing.
defaultQueueSize = 255
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
// processing. It must be greater than zero.
defaultCacheSize = 10_000
// defaultIPTTL is the Time to Live duration for IP addresses cached by
// rDNS and WHOIS.
defaultIPTTL = 1 * time.Hour
)
// initRDNS initializes the rDNS.
func initRDNS() {
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
// TODO(s.chzhen): Add ability to disable it on dns server configuration
// update in [dnsforward] package.
r := rdns.New(&rdns.Config{
Exchanger: Context.dnsServer,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
go processRDNS(r)
}
// processRDNS processes reverse DNS lookup queries. It is intended to be used
// as a goroutine.
func processRDNS(r rdns.Interface) {
defer log.OnPanic("rdns")
for ip := range Context.rdnsCh {
ok := Context.dnsServer.ShouldResolveClient(ip)
if !ok {
continue
}
host, changed := r.Process(ip)
if host == "" || !changed {
continue
}
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
if ok {
continue
}
log.Debug(
"dns: can't set rdns info for client %q: already set with higher priority source",
ip,
)
}
}
// initWHOIS initializes the WHOIS. // initWHOIS initializes the WHOIS.
// //
// TODO(s.chzhen): Consider making configurable. // TODO(s.chzhen): Consider making configurable.
func initWHOIS() { func initWHOIS() {
const ( const (
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
defaultQueueSize = 255
// defaultTimeout is the timeout for WHOIS requests. // defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second defaultTimeout = 5 * time.Second
// defaultCacheSize is the maximum size of the cache. If it's zero,
// cache size is unlimited.
defaultCacheSize = 10_000
// defaultMaxConnReadSize is an upper limit in bytes for reading from // defaultMaxConnReadSize is an upper limit in bytes for reading from
// net.Conn. // net.Conn.
defaultMaxConnReadSize = 64 * 1024 defaultMaxConnReadSize = 64 * 1024
@ -200,9 +248,6 @@ func initWHOIS() {
// defaultMaxInfoLen is the maximum length of whois.Info fields. // defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250 defaultMaxInfoLen = 250
// defaultIPTTL is the Time to Live duration for cached IP addresses.
defaultIPTTL = 1 * time.Hour
) )
Context.whoisCh = make(chan netip.Addr, defaultQueueSize) Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
@ -274,11 +319,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
return return
} }
srcs := config.Clients.Sources Context.rdnsCh <- ip
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
Context.whoisCh <- ip Context.whoisCh <- ip
} }
@ -517,11 +558,7 @@ func startDNSServer() error {
const topClientsNumber = 100 // the number of clients to get const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) { for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
srcs := config.Clients.Sources Context.rdnsCh <- ip
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
Context.whoisCh <- ip Context.whoisCh <- ip
} }

View file

@ -56,7 +56,6 @@ type homeContext struct {
stats stats.Interface // statistics module stats stats.Interface // statistics module
queryLog querylog.QueryLog // query log module queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
dhcpServer dhcpd.Interface // DHCP module dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module filters *filtering.DNSFilter // DNS filtering module
@ -83,6 +82,9 @@ type homeContext struct {
client *http.Client client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// rdnsCh is the channel for receiving IPs for rDNS processing.
rdnsCh chan netip.Addr
// whoisCh is the channel for receiving IPs for WHOIS processing. // whoisCh is the channel for receiving IPs for WHOIS processing.
whoisCh chan netip.Addr whoisCh chan netip.Addr

View file

@ -1,143 +0,0 @@
package home
import (
"encoding/binary"
"net/netip"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// RDNS resolves clients' addresses to enrich their metadata.
type RDNS struct {
exchanger dnsforward.RDNSExchanger
clients *clientsContainer
// ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan netip.Addr
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients the
// address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some time.
ipCache cache.Cache
// usePrivate stores the state of current private reverse-DNS resolving
// settings.
usePrivate atomic.Bool
}
// Default AdGuard Home reverse DNS values.
const (
revDNSCacheSize = 10000
// TODO(e.burkov): Make these values configurable.
revDNSCacheTTL = 24 * 60 * 60
revDNSFailureCacheTTL = 1 * 60 * 60
revDNSQueueSize = 256
)
// NewRDNS creates and returns initialized RDNS.
func NewRDNS(
exchanger dnsforward.RDNSExchanger,
clients *clientsContainer,
usePrivate bool,
) (rDNS *RDNS) {
rDNS = &RDNS{
exchanger: exchanger,
clients: clients,
ipCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: revDNSCacheSize,
}),
ipCh: make(chan netip.Addr, revDNSQueueSize),
}
rDNS.usePrivate.Store(usePrivate)
go rDNS.workerLoop()
return rDNS
}
// ensurePrivateCache ensures that the state of the RDNS cache is consistent
// with the current private client RDNS resolving settings.
//
// TODO(e.burkov): Clearing cache each time this value changed is not a perfect
// approach since only unresolved locally-served addresses should be removed.
// Implement when improving the cache.
func (r *RDNS) ensurePrivateCache() {
usePrivate := r.exchanger.ResolvesPrivatePTR()
if r.usePrivate.CompareAndSwap(!usePrivate, usePrivate) {
r.ipCache.Clear()
}
}
// isCached returns true if ip is already cached and not expired yet. It also
// caches it otherwise.
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
return binary.BigEndian.Uint64(expire) > now
}
return false
}
// cache caches the ip address for ttl seconds.
func (r *RDNS) cache(ip netip.Addr, ttl uint64) {
ipData := ip.AsSlice()
ttlData := [8]byte{}
binary.BigEndian.PutUint64(ttlData[:], uint64(time.Now().Unix())+ttl)
r.ipCache.Set(ipData, ttlData[:])
}
// Begin adds the ip to the resolving queue if it is not cached or already
// resolved.
func (r *RDNS) Begin(ip netip.Addr) {
r.ensurePrivateCache()
if r.isCached(ip) || r.clients.clientSource(ip) > ClientSourceRDNS {
return
}
select {
case r.ipCh <- ip:
log.Debug("rdns: %q added to queue", ip)
default:
log.Debug("rdns: queue is full")
}
}
// workerLoop handles incoming IP addresses from ipChan and adds it into
// clients.
func (r *RDNS) workerLoop() {
defer log.OnPanic("rdns")
for ip := range r.ipCh {
ttl := uint64(revDNSCacheTTL)
host, err := r.exchanger.Exchange(ip.AsSlice())
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
if errors.Is(err, dnsforward.ErrRDNSFailed) {
// Cache failure for a less time.
ttl = revDNSFailureCacheTTL
}
}
r.cache(ip, ttl)
if host != "" {
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}
}

View file

@ -1,264 +0,0 @@
package home
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRDNS_Begin(t *testing.T) {
aghtest.ReplaceLogLevel(t, log.DEBUG)
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
testCases := []struct {
cliIDIndex map[string]*Client
customChan chan netip.Addr
name string
wantLog string
ip netip.Addr
wantCacheHit int
wantCacheMiss int
}{{
cliIDIndex: map[string]*Client{},
customChan: nil,
name: "cached",
wantLog: "",
ip: ip1234,
wantCacheHit: 1,
wantCacheMiss: 0,
}, {
cliIDIndex: map[string]*Client{},
customChan: nil,
name: "not_cached",
wantLog: "rdns: queue is full",
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
cliIDIndex: map[string]*Client{"1.2.3.5": {}},
customChan: nil,
name: "already_in_clients",
wantLog: "",
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
cliIDIndex: map[string]*Client{},
customChan: make(chan netip.Addr, 1),
name: "add_to_queue",
wantLog: `rdns: "1.2.3.5" added to queue`,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}}
for _, tc := range testCases {
w.Reset()
ipCache := cache.New(cache.Config{
EnableLRU: true,
MaxCount: revDNSCacheSize,
})
ttl := make([]byte, binary.Size(uint64(0)))
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
rdns := &RDNS{
ipCache: ipCache,
exchanger: &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
},
clients: &clientsContainer{
list: map[string]*Client{},
idIndex: tc.cliIDIndex,
ipToRC: map[netip.Addr]*RuntimeClient{},
allTags: stringutil.NewSet(),
},
}
ipCache.Clear()
ipCache.Set(net.IP{1, 2, 3, 4}, ttl)
if tc.customChan != nil {
rdns.ipCh = tc.customChan
defer close(tc.customChan)
}
t.Run(tc.name, func(t *testing.T) {
rdns.Begin(tc.ip)
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
assert.Contains(t, w.String(), tc.wantLog)
})
}
}
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
type rDNSExchanger struct {
ex upstream.Upstream
usePrivate bool
}
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
rev, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
}
req := &dns.Msg{
Question: []dns.Question{{
Name: dns.Fqdn(rev),
Qclass: dns.ClassINET,
Qtype: dns.TypePTR,
}},
}
resp, err := e.ex.Exchange(req)
if err != nil {
return "", err
}
if len(resp.Answer) == 0 {
return "", nil
}
return resp.Answer[0].Header().Name, nil
}
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
func (e *rDNSExchanger) ResolvesPrivatePTR() (ok bool) {
return e.usePrivate
}
func TestRDNS_ensurePrivateCache(t *testing.T) {
data := []byte{1, 2, 3, 4}
ipCache := cache.New(cache.Config{
EnableLRU: true,
MaxCount: revDNSCacheSize,
})
ex := &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
}
rdns := &RDNS{
ipCache: ipCache,
exchanger: ex,
}
rdns.ipCache.Set(data, data)
require.NotZero(t, rdns.ipCache.Stats().Count)
ex.usePrivate = !ex.usePrivate
rdns.ensurePrivateCache()
require.Zero(t, rdns.ipCache.Stats().Count)
}
func TestRDNS_WorkerLoop(t *testing.T) {
aghtest.ReplaceLogLevel(t, log.DEBUG)
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
localIP := netip.MustParseAddr("192.168.1.1")
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err)
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revIPv4, "local.domain"),
aghtest.MatchedResponse(req, dns.TypePTR, revIPv6, "ipv6.domain"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
}
errUpstream := aghtest.NewErrorUpstream()
testCases := []struct {
ups upstream.Upstream
cliIP netip.Addr
wantLog string
name string
wantClientSource clientSource
}{{
ups: locUpstream,
cliIP: localIP,
wantLog: "",
name: "all_good",
wantClientSource: ClientSourceRDNS,
}, {
ups: errUpstream,
cliIP: netip.MustParseAddr("192.168.1.2"),
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error",
wantClientSource: ClientSourceNone,
}, {
ups: locUpstream,
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
wantLog: "",
name: "ipv6_good",
wantClientSource: ClientSourceRDNS,
}}
for _, tc := range testCases {
w.Reset()
cc := newClientsContainer(t)
ch := make(chan netip.Addr)
rdns := &RDNS{
exchanger: &rDNSExchanger{
ex: tc.ups,
},
clients: cc,
ipCh: ch,
ipCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: revDNSCacheSize,
}),
}
t.Run(tc.name, func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
rdns.workerLoop()
wg.Done()
}()
ch <- tc.cliIP
close(ch)
wg.Wait()
if tc.wantLog != "" {
assert.Contains(t, w.String(), tc.wantLog)
}
assert.Equal(t, tc.wantClientSource, cc.clientSource(tc.cliIP))
})
}
}

132
internal/rdns/rdns.go Normal file
View file

@ -0,0 +1,132 @@
// Package rdns processes reverse DNS lookup queries.
package rdns
import (
"net/netip"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache"
)
// Interface processes rDNS queries.
type Interface interface {
// Process makes rDNS request and returns domain name. changed indicates
// that domain name was updated since last request.
Process(ip netip.Addr) (host string, changed bool)
}
// Empty is an empty [Inteface] implementation which does nothing.
type Empty struct{}
// type check
var _ Interface = (*Empty)(nil)
// Process implements the [Interface] interface for Empty.
func (Empty) Process(_ netip.Addr) (host string, changed bool) {
return "", false
}
// Exchanger is a resolver for clients' addresses.
type Exchanger interface {
// Exchange tries to resolve the ip in a suitable way, i.e. either as local
// or as external.
Exchange(ip netip.Addr) (host string, err error)
}
// Config is the configuration structure for Default.
type Config struct {
// Exchanger resolves IP addresses to domain names.
Exchanger Exchanger
// CacheSize is the maximum size of the cache. It must be greater than
// zero.
CacheSize int
// CacheTTL is the Time to Live duration for cached IP addresses.
CacheTTL time.Duration
}
// Default is the default rDNS query processor.
type Default struct {
// cache is the cache containing IP addresses of clients. An active IP
// address is resolved once again after it expires. If IP address couldn't
// be resolved, it stays here for some time to prevent further attempts to
// resolve the same IP.
cache gcache.Cache
// exchanger resolves IP addresses to domain names.
exchanger Exchanger
// cacheTTL is the Time to Live duration for cached IP addresses.
cacheTTL time.Duration
}
// New returns a new default rDNS query processor. conf must not be nil.
func New(conf *Config) (r *Default) {
return &Default{
cache: gcache.New(conf.CacheSize).LRU().Build(),
exchanger: conf.Exchanger,
cacheTTL: conf.CacheTTL,
}
}
// type check
var _ Interface = (*Default)(nil)
// Process implements the [Interface] interface for Default.
func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
fromCache, expired := r.findInCache(ip)
if !expired {
return fromCache, false
}
host, err := r.exchanger.Exchange(ip)
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
}
item := &cacheItem{
expiry: time.Now().Add(r.cacheTTL),
host: host,
}
err = r.cache.Set(ip, item)
if err != nil {
log.Debug("rdns: cache: adding item %q: %s", ip, err)
}
return host, fromCache == "" || host != fromCache
}
// findInCache finds domain name in the cache. expired is true if host is not
// valid anymore.
func (r *Default) findInCache(ip netip.Addr) (host string, expired bool) {
val, err := r.cache.Get(ip)
if err != nil {
if !errors.Is(err, gcache.KeyNotFoundError) {
log.Debug("rdns: cache: retrieving %q: %s", ip, err)
}
return "", true
}
item, ok := val.(*cacheItem)
if !ok {
log.Debug("rdns: cache: %q bad type %T", ip, val)
return "", true
}
return item.host, time.Now().After(item.expiry)
}
// cacheItem represents an item that we will store in the cache.
type cacheItem struct {
// expiry is the time when cacheItem will expire.
expiry time.Time
// host is the domain name of a runtime client.
host string
}

105
internal/rdns/rdns_test.go Normal file
View file

@ -0,0 +1,105 @@
package rdns_test
import (
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeRDNSExchanger is a mock [rdns.Exchanger] implementation for tests.
type fakeRDNSExchanger struct {
OnExchange func(ip netip.Addr) (host string, err error)
}
// type check
var _ rdns.Exchanger = (*fakeRDNSExchanger)(nil)
// Exchange implements [rdns.Exchanger] interface for *fakeRDNSExchanger.
func (e *fakeRDNSExchanger) Exchange(ip netip.Addr) (host string, err error) {
return e.OnExchange(ip)
}
func TestDefault_Process(t *testing.T) {
ip1 := netip.MustParseAddr("1.2.3.4")
revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice())
require.NoError(t, err)
ip2 := netip.MustParseAddr("4.3.2.1")
revAddr2, err := netutil.IPToReversedAddr(ip2.AsSlice())
require.NoError(t, err)
localIP := netip.MustParseAddr("192.168.0.1")
localRevAddr1, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err)
config := &rdns.Config{
CacheSize: 100,
CacheTTL: time.Hour,
}
testCases := []struct {
name string
addr netip.Addr
want string
}{{
name: "first",
addr: ip1,
want: revAddr1,
}, {
name: "second",
addr: ip2,
want: revAddr2,
}, {
name: "empty",
addr: netip.MustParseAddr("0.0.0.0"),
want: "",
}, {
name: "private",
addr: localIP,
want: localRevAddr1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hit := 0
onExchange := func(ip netip.Addr) (host string, err error) {
hit++
switch ip {
case ip1:
return revAddr1, nil
case ip2:
return revAddr2, nil
case localIP:
return localRevAddr1, nil
default:
return "", nil
}
}
exchanger := &fakeRDNSExchanger{
OnExchange: onExchange,
}
config.Exchanger = exchanger
r := rdns.New(config)
got, changed := r.Process(tc.addr)
require.True(t, changed)
assert.Equal(t, tc.want, got)
assert.Equal(t, 1, hit)
// From cache.
got, changed = r.Process(tc.addr)
require.False(t, changed)
assert.Equal(t, tc.want, got)
assert.Equal(t, 1, hit)
})
}
}

View file

@ -177,6 +177,7 @@ run_linter gocognit --over 10\
./internal/aghhttp/\ ./internal/aghhttp/\
./internal/aghio/\ ./internal/aghio/\
./internal/next/\ ./internal/next/\
./internal/rdns/\
./internal/tools/\ ./internal/tools/\
./internal/version/\ ./internal/version/\
./internal/whois/\ ./internal/whois/\