mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-02-21 12:29:03 +03:00
Pull request: 5035-more-clients-netip-addr
Updates #5035. Squashed commit of the following: commit 1934ea14299921760e9fcf6dd9053bd3155cb40e Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Nov 9 14:19:54 2022 +0300 all: move more client code to netip.Addr
This commit is contained in:
parent
98af0e000e
commit
167b112511
17 changed files with 164 additions and 210 deletions
|
@ -31,12 +31,6 @@ var (
|
||||||
// the IP being static is available.
|
// the IP being static is available.
|
||||||
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
|
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
|
||||||
|
|
||||||
// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
|
|
||||||
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
|
||||||
|
|
||||||
// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
|
|
||||||
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }
|
|
||||||
|
|
||||||
// IfaceHasStaticIP checks if interface is configured to have static IP address.
|
// IfaceHasStaticIP checks if interface is configured to have static IP address.
|
||||||
// If it can't give a definitive answer, it returns false and an error for which
|
// If it can't give a definitive answer, it returns false and an error for which
|
||||||
// errors.Is(err, ErrNoStaticIPInfo) is true.
|
// errors.Is(err, ErrNoStaticIPInfo) is true.
|
||||||
|
|
|
@ -188,7 +188,7 @@ func TestBroadcastFromIPNet(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckPort(t *testing.T) {
|
func TestCheckPort(t *testing.T) {
|
||||||
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
|
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||||
|
|
||||||
t.Run("tcp_bound", func(t *testing.T) {
|
t.Run("tcp_bound", func(t *testing.T) {
|
||||||
l, err := net.Listen("tcp", laddr.String())
|
l, err := net.Listen("tcp", laddr.String())
|
||||||
|
|
|
@ -23,16 +23,6 @@ func ValidateClientID(id string) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's
|
|
||||||
// a helper function to prevent unnecessary allocations in code like:
|
|
||||||
//
|
|
||||||
// if strings.HasSuffix(s, "." + suffix) { /* … */ }
|
|
||||||
//
|
|
||||||
// s must be longer than suffix.
|
|
||||||
func hasLabelSuffix(s, suffix string) (ok bool) {
|
|
||||||
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
||||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||||
// client. When strict is true, and client and host server name don't match,
|
// client. When strict is true, and client and host server name don't match,
|
||||||
|
@ -46,7 +36,7 @@ func clientIDFromClientServerName(
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasLabelSuffix(cliSrvName, hostSrvName) {
|
if !netutil.IsImmediateSubdomain(cliSrvName, hostSrvName) {
|
||||||
if !strict {
|
if !strict {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -246,6 +246,7 @@ type RDNSExchanger interface {
|
||||||
// Exchange tries to resolve the ip in a suitable way, e.g. either as
|
// Exchange tries to resolve the ip in a suitable way, e.g. either as
|
||||||
// local or as external.
|
// local or as external.
|
||||||
Exchange(ip net.IP) (host string, err error)
|
Exchange(ip net.IP) (host string, err error)
|
||||||
|
|
||||||
// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
|
// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
|
||||||
// resolve PTR requests for locally-served addresses.
|
// resolve PTR requests for locally-served addresses.
|
||||||
ResolvesPrivatePTR() (ok bool)
|
ResolvesPrivatePTR() (ok bool)
|
||||||
|
@ -261,6 +262,9 @@ const (
|
||||||
rDNSNotPTRErr errors.Error = "the response is not a ptr"
|
rDNSNotPTRErr errors.Error = "the response is not a ptr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ RDNSExchanger = (*Server)(nil)
|
||||||
|
|
||||||
// Exchange implements the RDNSExchanger interface for *Server.
|
// Exchange implements the RDNSExchanger interface for *Server.
|
||||||
func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||||
s.serverLock.RLock()
|
s.serverLock.RLock()
|
||||||
|
@ -675,21 +679,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// IsBlockedClient returns true if the client is blocked by the current access
|
// IsBlockedClient returns true if the client is blocked by the current access
|
||||||
// settings.
|
// settings.
|
||||||
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) {
|
||||||
s.serverLock.RLock()
|
s.serverLock.RLock()
|
||||||
defer s.serverLock.RUnlock()
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
blockedByIP := false
|
blockedByIP := false
|
||||||
if ip != nil {
|
if ip != (netip.Addr{}) {
|
||||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
blockedByIP, rule = s.access.isBlockedIP(ip)
|
||||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("dnsforward: bad client ip %v: %s", ip, err)
|
|
||||||
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
blockedByIP, rule = s.access.isBlockedIP(ipAddr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
allowlistMode := s.access.allowlistMode()
|
allowlistMode := s.access.allowlistMode()
|
||||||
|
|
|
@ -19,13 +19,13 @@ func (s *Server) beforeRequestHandler(
|
||||||
_ *proxy.Proxy,
|
_ *proxy.Proxy,
|
||||||
pctx *proxy.DNSContext,
|
pctx *proxy.DNSContext,
|
||||||
) (reply bool, err error) {
|
) (reply bool, err error) {
|
||||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
|
||||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("getting clientid: %w", err)
|
return false, fmt.Errorf("getting clientid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
blocked, _ := s.IsBlockedClient(ip, clientID)
|
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
|
||||||
|
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
|
||||||
if blocked {
|
if blocked {
|
||||||
return s.preBlockedResponse(pctx)
|
return s.preBlockedResponse(pctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -40,7 +40,7 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||||
addr := l.Addr()
|
addr := l.Addr()
|
||||||
require.IsType(t, new(net.TCPAddr), addr)
|
require.IsType(t, new(net.TCPAddr), addr)
|
||||||
|
|
||||||
return netip.AddrPortFrom(aghnet.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilters(t *testing.T) {
|
func TestFilters(t *testing.T) {
|
||||||
|
|
|
@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct {
|
||||||
|
|
||||||
type clientsContainer struct {
|
type clientsContainer struct {
|
||||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||||
// different types (string, net.IP, and so on).
|
// different types (string, netip.Addr, and so on).
|
||||||
list map[string]*Client // name -> client
|
list map[string]*Client // name -> client
|
||||||
idIndex map[string]*Client // ID -> client
|
idIndex map[string]*Client // ID -> client
|
||||||
|
|
||||||
|
@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// exists checks if client with this IP address already exists.
|
// exists checks if client with this IP address already exists.
|
||||||
func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) {
|
func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
rc, ok := clients.ipToRC[ip]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||||
var artClient *querylog.Client
|
var artClient *querylog.Client
|
||||||
var art bool
|
var art bool
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
c, art = clients.clientOrArtificial(net.ParseIP(id), id)
|
ip, _ := netip.ParseAddr(id)
|
||||||
|
c, art = clients.clientOrArtificial(ip, id)
|
||||||
if art {
|
if art {
|
||||||
artClient = c
|
artClient = c
|
||||||
|
|
||||||
|
@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||||
// records about this client besides maybe whether or not it is blocked. c is
|
// records about this client besides maybe whether or not it is blocked. c is
|
||||||
// never nil.
|
// never nil.
|
||||||
func (clients *clientsContainer) clientOrArtificial(
|
func (clients *clientsContainer) clientOrArtificial(
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
id string,
|
id string,
|
||||||
) (c *querylog.Client, art bool) {
|
) (c *querylog.Client, art bool) {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||||
}, false
|
}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip == nil {
|
|
||||||
// Technically should never happen, but still.
|
|
||||||
return &querylog.Client{
|
|
||||||
Name: "",
|
|
||||||
}, true
|
|
||||||
}
|
|
||||||
|
|
||||||
var rc *RuntimeClient
|
var rc *RuntimeClient
|
||||||
rc, ok = clients.findRuntimeClient(ip)
|
rc, ok = clients.findRuntimeClient(ip)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP(id)
|
ip, err := netip.ParseAddr(id)
|
||||||
if ip == nil {
|
if err != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c = range clients.list {
|
for _, c = range clients.list {
|
||||||
for _, id := range c.IDs {
|
for _, id := range c.IDs {
|
||||||
_, ipnet, err := net.ParseCIDR(id)
|
var n netip.Prefix
|
||||||
|
n, err = netip.ParsePrefix(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipnet.Contains(ip) {
|
if n.Contains(ip) {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
macFound := clients.dhcpServer.FindMACbyIP(ip)
|
macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice())
|
||||||
if macFound == nil {
|
if macFound == nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c = range clients.list {
|
for _, c = range clients.list {
|
||||||
for _, id := range c.IDs {
|
for _, id := range c.IDs {
|
||||||
hwAddr, err := net.ParseMAC(id)
|
var mac net.HardwareAddr
|
||||||
|
mac, err = net.ParseMAC(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Equal(hwAddr, macFound) {
|
if bytes.Equal(mac, macFound) {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
|
||||||
// internal use only.
|
|
||||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
|
||||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
|
||||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("clients: bad client ip %v: %s", ip, err)
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, ok = clients.ipToRC[ipAddr]
|
|
||||||
|
|
||||||
return rc, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRuntimeClient finds a runtime client by their IP.
|
// findRuntimeClient finds a runtime client by their IP.
|
||||||
func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
|
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||||
if ip == nil {
|
if ip == (netip.Addr{}) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
return clients.findRuntimeClientLocked(ip)
|
rc, ok = clients.ipToRC[ip]
|
||||||
|
|
||||||
|
return rc, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// check validates the client.
|
// check validates the client.
|
||||||
|
@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||||
|
|
||||||
for i, id := range c.IDs {
|
for i, id := range c.IDs {
|
||||||
// Normalize structured data.
|
// Normalize structured data.
|
||||||
var ip net.IP
|
var (
|
||||||
var ipnet *net.IPNet
|
ip netip.Addr
|
||||||
var mac net.HardwareAddr
|
n netip.Prefix
|
||||||
if ip = net.ParseIP(id); ip != nil {
|
mac net.HardwareAddr
|
||||||
|
)
|
||||||
|
|
||||||
|
if ip, err = netip.ParseAddr(id); err == nil {
|
||||||
c.IDs[i] = ip.String()
|
c.IDs[i] = ip.String()
|
||||||
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
|
} else if n, err = netip.ParsePrefix(id); err == nil {
|
||||||
ipnet.IP = ip
|
c.IDs[i] = n.String()
|
||||||
c.IDs[i] = ipnet.String()
|
|
||||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||||
c.IDs[i] = mac.String()
|
c.IDs[i] = mac.String()
|
||||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||||
|
@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// setWHOISInfo sets the WHOIS information for a client.
|
// setWHOISInfo sets the WHOIS information for a client.
|
||||||
func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
|
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
rc, ok := clients.ipToRC[ip]
|
||||||
if ok {
|
if ok {
|
||||||
rc.WHOISInfo = wi
|
rc.WHOISInfo = wi
|
||||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||||
|
@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||||
|
|
||||||
rc.WHOISInfo = wi
|
rc.WHOISInfo = wi
|
||||||
|
|
||||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
clients.ipToRC[ip] = rc
|
||||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("clients: bad client ip %v: %s", ip, err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clients.ipToRC[ipAddr] = rc
|
|
||||||
|
|
||||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||||
// taken into account. ok is true if the pairing was added.
|
// taken into account. ok is true if the pairing was added.
|
||||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
func (clients *clientsContainer) AddHost(
|
||||||
|
ip netip.Addr,
|
||||||
|
host string,
|
||||||
|
src clientSource,
|
||||||
|
) (ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
return clients.addHostLocked(ip, host, src)
|
||||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("adding host: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clients.addHostLocked(ipAddr, host, src), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
||||||
|
|
|
@ -22,8 +22,18 @@ func TestClients(t *testing.T) {
|
||||||
clients.Init(nil, nil, nil, nil)
|
clients.Init(nil, nil, nil, nil)
|
||||||
|
|
||||||
t.Run("add_success", func(t *testing.T) {
|
t.Run("add_success", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliNone = "1.2.3.4"
|
||||||
|
cli1 = "1.1.1.1"
|
||||||
|
cli2 = "2.2.2.2"
|
||||||
|
|
||||||
|
cliNoneIP = netip.MustParseAddr(cliNone)
|
||||||
|
cli1IP = netip.MustParseAddr(cli1)
|
||||||
|
cli2IP = netip.MustParseAddr(cli2)
|
||||||
|
)
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +43,7 @@ func TestClients(t *testing.T) {
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c = &Client{
|
c = &Client{
|
||||||
IDs: []string{"2.2.2.2"},
|
IDs: []string{cli2},
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +52,7 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c, ok = clients.Find("1.1.1.1")
|
c, ok = clients.Find(cli1)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
@ -52,14 +62,14 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
|
||||||
c, ok = clients.Find("2.2.2.2")
|
c, ok = clients.Find(cli2)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client2", c.Name)
|
assert.Equal(t, "client2", c.Name)
|
||||||
|
|
||||||
assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
|
assert.False(t, clients.exists(cliNoneIP, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
assert.True(t, clients.exists(cli1IP, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
|
assert.True(t, clients.exists(cli2IP, ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_name", func(t *testing.T) {
|
t.Run("add_fail_name", func(t *testing.T) {
|
||||||
|
@ -103,23 +113,31 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_success", func(t *testing.T) {
|
t.Run("update_success", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliOld = "1.1.1.1"
|
||||||
|
cliNew = "1.1.1.2"
|
||||||
|
|
||||||
|
cliOldIP = netip.MustParseAddr(cliOld)
|
||||||
|
cliNewIP = netip.MustParseAddr(cliNew)
|
||||||
|
)
|
||||||
|
|
||||||
err := clients.Update("client1", &Client{
|
err := clients.Update("client1", &Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{cliNew},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
assert.False(t, clients.exists(cliOldIP, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
assert.True(t, clients.exists(cliNewIP, ClientSourceHostsFile))
|
||||||
|
|
||||||
err = clients.Update("client1", &Client{
|
err = clients.Update("client1", &Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{cliNew},
|
||||||
Name: "client1-renamed",
|
Name: "client1-renamed",
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
c, ok := clients.Find("1.1.1.2")
|
c, ok := clients.Find(cliNew)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client1-renamed", c.Name)
|
assert.Equal(t, "client1-renamed", c.Name)
|
||||||
|
@ -132,14 +150,14 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
require.Len(t, c.IDs, 1)
|
require.Len(t, c.IDs, 1)
|
||||||
|
|
||||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
assert.Equal(t, cliNew, c.IDs[0])
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del_success", func(t *testing.T) {
|
t.Run("del_success", func(t *testing.T) {
|
||||||
ok := clients.Del("client1-renamed")
|
ok := clients.Del("client1-renamed")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
assert.False(t, clients.exists(netip.MustParseAddr("1.1.1.2"), ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del_fail", func(t *testing.T) {
|
t.Run("del_fail", func(t *testing.T) {
|
||||||
|
@ -148,45 +166,33 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
t.Run("addhost_success", func(t *testing.T) {
|
||||||
ip := net.IP{1, 1, 1, 1}
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
|
ok := clients.AddHost(ip, "host", ClientSourceARP)
|
||||||
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
|
ok = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
assert.True(t, clients.exists(ip, ClientSourceHostsFile))
|
assert.True(t, clients.exists(ip, ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||||
ip := net.IP{1, 2, 3, 4}
|
ip := netip.MustParseAddr("1.2.3.4")
|
||||||
|
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||||
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.True(t, clients.exists(ip, ClientSourceARP))
|
assert.True(t, clients.exists(ip, ClientSourceARP))
|
||||||
|
|
||||||
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.True(t, clients.exists(ip, ClientSourceDHCP))
|
assert.True(t, clients.exists(ip, ClientSourceDHCP))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_fail", func(t *testing.T) {
|
t.Run("addhost_fail", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
require.NoError(t, err)
|
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -203,7 +209,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("new_client", func(t *testing.T) {
|
t.Run("new_client", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.255")
|
ip := netip.MustParseAddr("1.1.1.255")
|
||||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.ipToRC[ip]
|
rc := clients.ipToRC[ip]
|
||||||
require.NotNil(t, rc)
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
|
@ -212,12 +218,10 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("existing_auto-client", func(t *testing.T) {
|
t.Run("existing_auto-client", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
ok, err := clients.AddHost(ip.AsSlice(), "host", ClientSourceRDNS)
|
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.ipToRC[ip]
|
rc := clients.ipToRC[ip]
|
||||||
require.NotNil(t, rc)
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
|
@ -234,7 +238,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.ipToRC[ip]
|
rc := clients.ipToRC[ip]
|
||||||
require.Nil(t, rc)
|
require.Nil(t, rc)
|
||||||
|
|
||||||
|
@ -249,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
clients.Init(nil, nil, nil, nil)
|
clients.Init(nil, nil, nil, nil)
|
||||||
|
|
||||||
t.Run("simple", func(t *testing.T) {
|
t.Run("simple", func(t *testing.T) {
|
||||||
ip := net.IP{1, 1, 1, 1}
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
// Add a client.
|
// Add a client.
|
||||||
ok, err := clients.Add(&Client{
|
ok, err := clients.Add(&Client{
|
||||||
|
@ -260,8 +264,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Now add an auto-client with the same IP.
|
// Now add an auto-client with the same IP.
|
||||||
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
|
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ package home
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
)
|
)
|
||||||
|
@ -47,8 +47,8 @@ type runtimeClientJSON struct {
|
||||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||||
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
IP netip.Addr `json:"ip"`
|
||||||
Source clientSource `json:"source"`
|
Source clientSource `json:"source"`
|
||||||
IP net.IP `json:"ip"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientListJSON struct {
|
type clientListJSON struct {
|
||||||
|
@ -75,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||||
|
|
||||||
Name: rc.Host,
|
Name: rc.Host,
|
||||||
Source: rc.Source,
|
Source: rc.Source,
|
||||||
IP: ip.AsSlice(),
|
IP: ip,
|
||||||
}
|
}
|
||||||
|
|
||||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||||
|
@ -218,7 +218,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP(idStr)
|
ip, _ := netip.ParseAddr(idStr)
|
||||||
c, ok := clients.Find(idStr)
|
c, ok := clients.Find(idStr)
|
||||||
var cj *clientJSON
|
var cj *clientJSON
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -240,7 +240,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||||
// non-nil.
|
// non-nil.
|
||||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||||
rc, ok := clients.findRuntimeClient(ip)
|
rc, ok := clients.findRuntimeClient(ip)
|
||||||
if !ok {
|
if !ok {
|
||||||
// It is still possible that the IP used to be in the runtime clients
|
// It is still possible that the IP used to be in the runtime clients
|
||||||
|
|
|
@ -71,9 +71,7 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
|
||||||
// on, including the addresses on all interfaces in cases of unspecified IPs.
|
// on, including the addresses on all interfaces in cases of unspecified IPs.
|
||||||
func collectDNSAddresses() (addrs []string, err error) {
|
func collectDNSAddresses() (addrs []string, err error) {
|
||||||
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
|
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
|
||||||
addr := aghnet.IPv4Localhost()
|
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
|
||||||
|
|
||||||
addrs = appendDNSAddrs(addrs, addr)
|
|
||||||
} else {
|
} else {
|
||||||
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
|
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
@ -150,8 +151,8 @@ func isRunning() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
|
||||||
if ip == nil {
|
if ip == (netip.Addr{}) {
|
||||||
// This would be quite weird if we get here.
|
// This would be quite weird if we get here.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -160,7 +161,8 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||||
if srcs.RDNS && !ip.IsLoopback() {
|
if srcs.RDNS && !ip.IsLoopback() {
|
||||||
Context.rdns.Begin(ip)
|
Context.rdns.Begin(ip)
|
||||||
}
|
}
|
||||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
|
||||||
|
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||||
Context.whois.Begin(ip)
|
Context.whois.Begin(ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,11 +195,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||||
|
|
||||||
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||||
dnsConf := config.DNS
|
dnsConf := config.DNS
|
||||||
hosts := dnsConf.BindHosts
|
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||||
if len(hosts) == 0 {
|
|
||||||
hosts = []netip.Addr{aghnet.IPv4Localhost()}
|
|
||||||
}
|
|
||||||
|
|
||||||
newConf = dnsforward.ServerConfig{
|
newConf = dnsforward.ServerConfig{
|
||||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||||
|
@ -400,15 +398,12 @@ func startDNSServer() error {
|
||||||
|
|
||||||
const topClientsNumber = 100 // the number of clients to get
|
const topClientsNumber = 100 // the number of clients to get
|
||||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||||
if ip == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
srcs := config.Clients.Sources
|
srcs := config.Clients.Sources
|
||||||
if srcs.RDNS && !ip.IsLoopback() {
|
if srcs.RDNS && !ip.IsLoopback() {
|
||||||
Context.rdns.Begin(ip)
|
Context.rdns.Begin(ip)
|
||||||
}
|
}
|
||||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
|
||||||
|
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||||
Context.whois.Begin(ip)
|
Context.whois.Begin(ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -576,7 +576,7 @@ func checkPermissions() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We should check if AdGuard Home is able to bind to port 53
|
// We should check if AdGuard Home is able to bind to port 53
|
||||||
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
|
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(netutil.IPv4Localhost(), defaultPortDNS))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrPermission) {
|
if errors.Is(err, os.ErrPermission) {
|
||||||
log.Fatal(`Permission check failed.
|
log.Fatal(`Permission check failed.
|
||||||
|
|
|
@ -2,7 +2,7 @@ package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ type RDNS struct {
|
||||||
usePrivate uint32
|
usePrivate uint32
|
||||||
|
|
||||||
// ipCh used to pass client's IP to rDNS workerLoop.
|
// ipCh used to pass client's IP to rDNS workerLoop.
|
||||||
ipCh chan net.IP
|
ipCh chan netip.Addr
|
||||||
|
|
||||||
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
||||||
// address stays here while it's inside clients. After leaving clients the
|
// address stays here while it's inside clients. After leaving clients the
|
||||||
|
@ -50,7 +50,7 @@ func NewRDNS(
|
||||||
EnableLRU: true,
|
EnableLRU: true,
|
||||||
MaxCount: defaultRDNSCacheSize,
|
MaxCount: defaultRDNSCacheSize,
|
||||||
}),
|
}),
|
||||||
ipCh: make(chan net.IP, defaultRDNSIPChSize),
|
ipCh: make(chan netip.Addr, defaultRDNSIPChSize),
|
||||||
}
|
}
|
||||||
if usePrivate {
|
if usePrivate {
|
||||||
rDNS.usePrivate = 1
|
rDNS.usePrivate = 1
|
||||||
|
@ -80,9 +80,10 @@ func (r *RDNS) ensurePrivateCache() {
|
||||||
|
|
||||||
// isCached returns true if ip is already cached and not expired yet. It also
|
// isCached returns true if ip is already cached and not expired yet. It also
|
||||||
// caches it otherwise.
|
// caches it otherwise.
|
||||||
func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
|
||||||
|
ipBytes := ip.AsSlice()
|
||||||
now := uint64(time.Now().Unix())
|
now := uint64(time.Now().Unix())
|
||||||
if expire := r.ipCache.Get(ip); len(expire) != 0 {
|
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
|
||||||
if binary.BigEndian.Uint64(expire) > now {
|
if binary.BigEndian.Uint64(expire) > now {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -91,14 +92,14 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||||
// The cache entry either expired or doesn't exist.
|
// The cache entry either expired or doesn't exist.
|
||||||
ttl := make([]byte, 8)
|
ttl := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL)
|
binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL)
|
||||||
r.ipCache.Set(ip, ttl)
|
r.ipCache.Set(ipBytes, ttl)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin adds the ip to the resolving queue if it is not cached or already
|
// Begin adds the ip to the resolving queue if it is not cached or already
|
||||||
// resolved.
|
// resolved.
|
||||||
func (r *RDNS) Begin(ip net.IP) {
|
func (r *RDNS) Begin(ip netip.Addr) {
|
||||||
r.ensurePrivateCache()
|
r.ensurePrivateCache()
|
||||||
|
|
||||||
if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) {
|
if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) {
|
||||||
|
@ -107,9 +108,9 @@ func (r *RDNS) Begin(ip net.IP) {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case r.ipCh <- ip:
|
case r.ipCh <- ip:
|
||||||
log.Tracef("rdns: %q added to queue", ip)
|
log.Debug("rdns: %q added to queue", ip)
|
||||||
default:
|
default:
|
||||||
log.Tracef("rdns: queue is full")
|
log.Debug("rdns: queue is full")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +120,7 @@ func (r *RDNS) workerLoop() {
|
||||||
defer log.OnPanic("rdns")
|
defer log.OnPanic("rdns")
|
||||||
|
|
||||||
for ip := range r.ipCh {
|
for ip := range r.ipCh {
|
||||||
host, err := r.exchanger.Exchange(ip)
|
host, err := r.exchanger.Exchange(ip.AsSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||||
|
|
||||||
|
@ -128,8 +129,6 @@ func (r *RDNS) workerLoop() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't handle any errors since AddHost doesn't return non-nil errors
|
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||||
// for now.
|
|
||||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,14 +27,14 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
aghtest.ReplaceLogWriter(t, w)
|
aghtest.ReplaceLogWriter(t, w)
|
||||||
|
|
||||||
ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}
|
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
cliIDIndex map[string]*Client
|
cliIDIndex map[string]*Client
|
||||||
customChan chan net.IP
|
customChan chan netip.Addr
|
||||||
name string
|
name string
|
||||||
wantLog string
|
wantLog string
|
||||||
req net.IP
|
ip netip.Addr
|
||||||
wantCacheHit int
|
wantCacheHit int
|
||||||
wantCacheMiss int
|
wantCacheMiss int
|
||||||
}{{
|
}{{
|
||||||
|
@ -42,7 +42,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
customChan: nil,
|
customChan: nil,
|
||||||
name: "cached",
|
name: "cached",
|
||||||
wantLog: "",
|
wantLog: "",
|
||||||
req: ip1234,
|
ip: ip1234,
|
||||||
wantCacheHit: 1,
|
wantCacheHit: 1,
|
||||||
wantCacheMiss: 0,
|
wantCacheMiss: 0,
|
||||||
}, {
|
}, {
|
||||||
|
@ -50,7 +50,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
customChan: nil,
|
customChan: nil,
|
||||||
name: "not_cached",
|
name: "not_cached",
|
||||||
wantLog: "rdns: queue is full",
|
wantLog: "rdns: queue is full",
|
||||||
req: ip1235,
|
ip: ip1235,
|
||||||
wantCacheHit: 0,
|
wantCacheHit: 0,
|
||||||
wantCacheMiss: 1,
|
wantCacheMiss: 1,
|
||||||
}, {
|
}, {
|
||||||
|
@ -58,15 +58,15 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
customChan: nil,
|
customChan: nil,
|
||||||
name: "already_in_clients",
|
name: "already_in_clients",
|
||||||
wantLog: "",
|
wantLog: "",
|
||||||
req: ip1235,
|
ip: ip1235,
|
||||||
wantCacheHit: 0,
|
wantCacheHit: 0,
|
||||||
wantCacheMiss: 1,
|
wantCacheMiss: 1,
|
||||||
}, {
|
}, {
|
||||||
cliIDIndex: map[string]*Client{},
|
cliIDIndex: map[string]*Client{},
|
||||||
customChan: make(chan net.IP, 1),
|
customChan: make(chan netip.Addr, 1),
|
||||||
name: "add_to_queue",
|
name: "add_to_queue",
|
||||||
wantLog: `rdns: "1.2.3.5" added to queue`,
|
wantLog: `rdns: "1.2.3.5" added to queue`,
|
||||||
req: ip1235,
|
ip: ip1235,
|
||||||
wantCacheHit: 0,
|
wantCacheHit: 0,
|
||||||
wantCacheMiss: 1,
|
wantCacheMiss: 1,
|
||||||
}}
|
}}
|
||||||
|
@ -102,7 +102,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rdns.Begin(tc.req)
|
rdns.Begin(tc.ip)
|
||||||
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
|
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
|
||||||
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
|
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
|
||||||
assert.Contains(t, w.String(), tc.wantLog)
|
assert.Contains(t, w.String(), tc.wantLog)
|
||||||
|
@ -179,8 +179,8 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
aghtest.ReplaceLogWriter(t, w)
|
aghtest.ReplaceLogWriter(t, w)
|
||||||
|
|
||||||
localIP := net.IP{192, 168, 1, 1}
|
localIP := netip.MustParseAddr("192.168.1.1")
|
||||||
revIPv4, err := netutil.IPToReversedAddr(localIP)
|
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||||
|
@ -201,24 +201,24 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
ups upstream.Upstream
|
ups upstream.Upstream
|
||||||
|
cliIP netip.Addr
|
||||||
wantLog string
|
wantLog string
|
||||||
name string
|
name string
|
||||||
cliIP net.IP
|
|
||||||
}{{
|
}{{
|
||||||
ups: locUpstream,
|
ups: locUpstream,
|
||||||
|
cliIP: localIP,
|
||||||
wantLog: "",
|
wantLog: "",
|
||||||
name: "all_good",
|
name: "all_good",
|
||||||
cliIP: localIP,
|
|
||||||
}, {
|
}, {
|
||||||
ups: errUpstream,
|
ups: errUpstream,
|
||||||
|
cliIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||||
name: "resolve_error",
|
name: "resolve_error",
|
||||||
cliIP: net.IP{192, 168, 1, 2},
|
|
||||||
}, {
|
}, {
|
||||||
ups: locUpstream,
|
ups: locUpstream,
|
||||||
|
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
|
||||||
wantLog: "",
|
wantLog: "",
|
||||||
name: "ipv6_good",
|
name: "ipv6_good",
|
||||||
cliIP: net.ParseIP("2a00:1450:400c:c06::93"),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -230,7 +230,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
ipToRC: map[netip.Addr]*RuntimeClient{},
|
ipToRC: map[netip.Addr]*RuntimeClient{},
|
||||||
allTags: stringutil.NewSet(),
|
allTags: stringutil.NewSet(),
|
||||||
}
|
}
|
||||||
ch := make(chan net.IP)
|
ch := make(chan netip.Addr)
|
||||||
rdns := &RDNS{
|
rdns := &RDNS{
|
||||||
exchanger: &rDNSExchanger{
|
exchanger: &rDNSExchanger{
|
||||||
ex: tc.ups,
|
ex: tc.ups,
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -26,7 +27,7 @@ const (
|
||||||
// WHOIS - module context
|
// WHOIS - module context
|
||||||
type WHOIS struct {
|
type WHOIS struct {
|
||||||
clients *clientsContainer
|
clients *clientsContainer
|
||||||
ipChan chan net.IP
|
ipChan chan netip.Addr
|
||||||
|
|
||||||
// dialContext specifies the dial function for creating unencrypted TCP
|
// dialContext specifies the dial function for creating unencrypted TCP
|
||||||
// connections.
|
// connections.
|
||||||
|
@ -51,7 +52,7 @@ func initWHOIS(clients *clientsContainer) *WHOIS {
|
||||||
MaxCount: 10000,
|
MaxCount: 10000,
|
||||||
}),
|
}),
|
||||||
dialContext: customDialContext,
|
dialContext: customDialContext,
|
||||||
ipChan: make(chan net.IP, 255),
|
ipChan: make(chan netip.Addr, 255),
|
||||||
}
|
}
|
||||||
|
|
||||||
go w.workerLoop()
|
go w.workerLoop()
|
||||||
|
@ -192,7 +193,7 @@ func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request WHOIS information
|
// Request WHOIS information
|
||||||
func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISInfo) {
|
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
|
||||||
resp, err := w.queryAll(ctx, ip.String())
|
resp, err := w.queryAll(ctx, ip.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("whois: error: %s IP:%s", err, ip)
|
log.Debug("whois: error: %s IP:%s", err, ip)
|
||||||
|
@ -220,24 +221,25 @@ func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISI
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin - begin requesting WHOIS info
|
// Begin - begin requesting WHOIS info
|
||||||
func (w *WHOIS) Begin(ip net.IP) {
|
func (w *WHOIS) Begin(ip netip.Addr) {
|
||||||
|
ipBytes := ip.AsSlice()
|
||||||
now := uint64(time.Now().Unix())
|
now := uint64(time.Now().Unix())
|
||||||
expire := w.ipAddrs.Get([]byte(ip))
|
expire := w.ipAddrs.Get(ipBytes)
|
||||||
if len(expire) != 0 {
|
if len(expire) != 0 {
|
||||||
exp := binary.BigEndian.Uint64(expire)
|
exp := binary.BigEndian.Uint64(expire)
|
||||||
if exp > now {
|
if exp > now {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// TTL expired
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expire = make([]byte, 8)
|
expire = make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
||||||
_ = w.ipAddrs.Set([]byte(ip), expire)
|
_ = w.ipAddrs.Set(ipBytes, expire)
|
||||||
|
|
||||||
log.Debug("whois: adding %s", ip)
|
log.Debug("whois: adding %s", ip)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case w.ipChan <- ip:
|
case w.ipChan <- ip:
|
||||||
//
|
|
||||||
default:
|
default:
|
||||||
log.Debug("whois: queue is full")
|
log.Debug("whois: queue is full")
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -64,7 +65,7 @@ type Interface interface {
|
||||||
|
|
||||||
// GetTopClientIP returns at most limit IP addresses corresponding to the
|
// GetTopClientIP returns at most limit IP addresses corresponding to the
|
||||||
// clients with the most number of requests.
|
// clients with the most number of requests.
|
||||||
TopClientsIP(limit uint) []net.IP
|
TopClientsIP(limit uint) []netip.Addr
|
||||||
|
|
||||||
// WriteDiskConfig puts the Interface's configuration to the dc.
|
// WriteDiskConfig puts the Interface's configuration to the dc.
|
||||||
WriteDiskConfig(dc *DiskConfig)
|
WriteDiskConfig(dc *DiskConfig)
|
||||||
|
@ -107,8 +108,6 @@ type StatsCtx struct {
|
||||||
filename string
|
filename string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Interface = &StatsCtx{}
|
|
||||||
|
|
||||||
// New creates s from conf and properly initializes it. Don't use s before
|
// New creates s from conf and properly initializes it. Don't use s before
|
||||||
// calling it's Start method.
|
// calling it's Start method.
|
||||||
func New(conf Config) (s *StatsCtx, err error) {
|
func New(conf Config) (s *StatsCtx, err error) {
|
||||||
|
@ -178,6 +177,9 @@ func withRecovered(orig *error) {
|
||||||
*orig = errors.WithDeferred(*orig, err)
|
*orig = errors.WithDeferred(*orig, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ Interface = (*StatsCtx)(nil)
|
||||||
|
|
||||||
// Start implements the Interface interface for *StatsCtx.
|
// Start implements the Interface interface for *StatsCtx.
|
||||||
func (s *StatsCtx) Start() {
|
func (s *StatsCtx) Start() {
|
||||||
s.initWeb()
|
s.initWeb()
|
||||||
|
@ -250,8 +252,8 @@ func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
|
||||||
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
|
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
|
||||||
}
|
}
|
||||||
|
|
||||||
// TopClientsIP implements the Interface interface for *StatsCtx.
|
// TopClientsIP implements the [Interface] interface for *StatsCtx.
|
||||||
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
|
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
|
||||||
limit := atomic.LoadUint32(&s.limitHours)
|
limit := atomic.LoadUint32(&s.limitHours)
|
||||||
if limit == 0 {
|
if limit == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -271,10 +273,10 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a := convertMapToSlice(m, int(maxCount))
|
a := convertMapToSlice(m, int(maxCount))
|
||||||
ips = []net.IP{}
|
ips = []netip.Addr{}
|
||||||
for _, it := range a {
|
for _, it := range a {
|
||||||
ip := net.ParseIP(it.Name)
|
ip, err := netip.ParseAddr(it.Name)
|
||||||
if ip != nil {
|
if err == nil {
|
||||||
ips = append(ips, ip)
|
ips = append(ips, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -45,7 +46,7 @@ func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req *
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStats(t *testing.T) {
|
func TestStats(t *testing.T) {
|
||||||
cliIP := net.IP{127, 0, 0, 1}
|
cliIP := netutil.IPv4Localhost()
|
||||||
cliIPStr := cliIP.String()
|
cliIPStr := cliIP.String()
|
||||||
|
|
||||||
handlers := map[string]http.Handler{}
|
handlers := map[string]http.Handler{}
|
||||||
|
@ -123,7 +124,7 @@ func TestStats(t *testing.T) {
|
||||||
topClients := s.TopClientsIP(2)
|
topClients := s.TopClientsIP(2)
|
||||||
require.NotEmpty(t, topClients)
|
require.NotEmpty(t, topClients)
|
||||||
|
|
||||||
assert.True(t, cliIP.Equal(topClients[0]))
|
assert.Equal(t, cliIP, topClients[0])
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset", func(t *testing.T) {
|
t.Run("reset", func(t *testing.T) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue