Pull request 2265: AG-27492-client-runtime-storage

Squashed commit of the following:

commit a164bace2e0333cf95622f34df7b0e79eac69f41
Merge: 6567cd330 184f476bd
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Aug 21 16:14:55 2024 +0300

    Merge branch 'master' into AG-27492-client-runtime-storage

commit 6567cd330c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Aug 20 16:45:43 2024 +0300

    all: imp code

commit 243123a404
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 15 19:15:54 2024 +0300

    all: add tests

commit 6489996878
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Aug 5 15:12:05 2024 +0300

    all: client runtime storage
This commit is contained in:
Stanislav Chzhen 2024-08-21 16:27:28 +03:00
parent 184f476bdc
commit 30c0bbe5cc
7 changed files with 420 additions and 88 deletions

View file

@ -8,6 +8,7 @@ import (
"encoding"
"fmt"
"net/netip"
"slices"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
@ -120,6 +121,7 @@ func (r *Runtime) Info() (cs Source, host string) {
// SetInfo sets a host as a client information from the cs.
func (r *Runtime) SetInfo(cs Source, hosts []string) {
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
if len(hosts) == 1 && hosts[0] == "" {
hosts = []string{}
}
@ -175,3 +177,15 @@ func (r *Runtime) isEmpty() (ok bool) {
func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}
// Clone returns a deep copy of the runtime client.
func (r *Runtime) Clone() (c *Runtime) {
return &Runtime{
ip: r.ip,
whois: r.whois.Clone(),
arp: slices.Clone(r.arp),
rdns: slices.Clone(r.rdns),
dhcp: slices.Clone(r.dhcp),
hostsFile: slices.Clone(r.hostsFile),
}
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"sync"
"github.com/AdguardTeam/golibs/container"
@ -31,8 +32,6 @@ type Storage struct {
index *index
// runtimeIndex contains information about runtime clients.
//
// TODO(s.chzhen): Use it.
runtimeIndex *RuntimeIndex
}
@ -236,20 +235,75 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
return s.runtimeIndex.Client(ip)
}
// AddRuntime saves the runtime client information in the storage. IP address
// of a client must be unique. rc must not be nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) AddRuntime(rc *Runtime) {
// UpdateRuntime updates the stored runtime client with information from rc. If
// no such client exists, saves the copy of rc in storage. rc must not be nil.
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Add(rc)
return s.updateRuntimeLocked(rc)
}
// updateRuntimeLocked updates the stored runtime client with information from
// rc. rc must not be nil. Storage.mu is expected to be locked.
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
stored := s.runtimeIndex.Client(rc.ip)
if stored == nil {
s.runtimeIndex.Add(rc.Clone())
return true
}
if rc.whois != nil {
stored.whois = rc.whois.Clone()
}
if rc.arp != nil {
stored.arp = slices.Clone(rc.arp)
}
if rc.rdns != nil {
stored.rdns = slices.Clone(rc.rdns)
}
if rc.dhcp != nil {
stored.dhcp = slices.Clone(rc.dhcp)
}
if rc.hostsFile != nil {
stored.hostsFile = slices.Clone(rc.hostsFile)
}
return false
}
// BatchUpdateBySource updates the stored runtime clients information from the
// specified source and returns the number of added and removed clients.
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
s.mu.Lock()
defer s.mu.Unlock()
for _, rc := range s.runtimeIndex.index {
rc.unset(src)
}
for _, rc := range rcs {
if s.updateRuntimeLocked(rc) {
added++
}
}
for ip, rc := range s.runtimeIndex.index {
if rc.isEmpty() {
delete(s.runtimeIndex.index, ip)
removed++
}
}
return added, removed
}
// SizeRuntime returns the number of the runtime clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
@ -258,8 +312,6 @@ func (s *Storage) SizeRuntime() (n int) {
}
// RangeRuntime calls f for each runtime client in an undefined order.
//
// TODO(s.chzhen): Use it.
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
@ -267,16 +319,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.runtimeIndex.Range(f)
}
// DeleteRuntime removes the runtime client by ip.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteRuntime(ip netip.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Delete(ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
//

View file

@ -6,6 +6,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -25,9 +26,19 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
require.NoError(tb, s.Add(c))
}
require.Equal(tb, len(m), s.Size())
return s
}
// newRuntimeClient is a helper function that returns a new runtime client.
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
rc = client.NewRuntime(ip)
rc.SetInfo(source, []string{host})
return rc
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
@ -43,6 +54,9 @@ func TestStorage_Add(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
allowedTag = "tag"
notAllowedTag = "not_allowed_tag"
)
var (
@ -60,7 +74,7 @@ func TestStorage_Add(t *testing.T) {
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
AllowedTags: []string{allowedTag},
})
err := s.Add(existingClient)
require.NoError(t, err)
@ -119,6 +133,15 @@ func TestStorage_Add(t *testing.T) {
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same ClientID "existing_client_id"`,
}, {
name: "not_allowed_tag",
cli: &client.Persistent{
Name: "nont_allowed_tag",
Tags: []string{notAllowedTag},
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
}}
for _, tc := range testCases {
@ -341,6 +364,127 @@ func TestStorage_FindLoose(t *testing.T) {
}
}
func TestStorage_FindByName(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
)
const (
clientExistingName = "client_existing"
clientAnotherExistingName = "client_another_existing"
nonExistingClientName = "client_non_existing"
)
var (
clientExisting = &client.Persistent{
Name: clientExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}
clientAnotherExisting = &client.Persistent{
Name: clientAnotherExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientName string
}{{
name: "existing",
clientName: clientExistingName,
want: clientExisting,
}, {
name: "another_existing",
clientName: clientAnotherExistingName,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientName: nonExistingClientName,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByName(tc.clientName)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_FindByMAC(t *testing.T) {
var (
cliMAC = mustParseMAC("11:11:11:11:11:11")
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
)
var (
clientExisting = &client.Persistent{
Name: "client",
MACs: []net.HardwareAddr{cliMAC},
}
clientAnotherExisting = &client.Persistent{
Name: "another_client",
MACs: []net.HardwareAddr{cliAnotherMAC},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientMAC net.HardwareAddr
}{{
name: "existing",
clientMAC: cliMAC,
want: clientExisting,
}, {
name: "another_existing",
clientMAC: cliAnotherMAC,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientMAC: nonExistingClientMAC,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByMAC(tc.clientMAC)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_Update(t *testing.T) {
const (
clientName = "client_name"
@ -479,3 +623,157 @@ func TestStorage_RangeByName(t *testing.T) {
})
}
}
func TestStorage_UpdateRuntime(t *testing.T) {
const (
addedARP = "added_arp"
addedSecondARP = "added_arp"
updatedARP = "updated_arp"
cliCity = "City"
cliCountry = "Country"
cliOrgname = "Orgname"
)
var (
ip = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
)
updated := client.NewRuntime(ip)
updated.SetInfo(client.SourceARP, []string{updatedARP})
info := &whois.Info{
City: cliCity,
Country: cliCountry,
Orgname: cliOrgname,
}
updated.SetWHOIS(info)
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
t.Run("add_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip)
added.SetInfo(client.SourceARP, []string{addedARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedARP, host)
})
t.Run("add_second_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip2)
added.SetInfo(client.SourceARP, []string{addedSecondARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 2, s.SizeRuntime())
got := s.ClientRuntime(ip2)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedSecondARP, host)
})
t.Run("update_first_client", func(t *testing.T) {
require.False(t, s.UpdateRuntime(updated))
got := s.ClientRuntime(ip)
require.Equal(t, 2, s.SizeRuntime())
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, updatedARP, host)
})
t.Run("remove_arp_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceARP)
require.Equal(t, 1, n)
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, _ := got.Info()
assert.Equal(t, client.SourceWHOIS, source)
assert.Equal(t, info, got.WHOIS())
})
t.Run("remove_whois_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceWHOIS)
require.Equal(t, 1, n)
require.Equal(t, 0, s.SizeRuntime())
})
}
func TestStorage_BatchUpdateBySource(t *testing.T) {
const (
defSrc = client.SourceARP
cliFirstHost1 = "host1"
cliFirstHost2 = "host2"
cliUpdatedHost3 = "host3"
cliUpdatedHost4 = "host4"
cliUpdatedHost5 = "host5"
)
var (
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
)
firstClients := []*client.Runtime{
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
}
updatedClients := []*client.Runtime{
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
require.Equal(t, len(firstClients), added)
require.Equal(t, 0, removed)
require.Equal(t, len(firstClients), s.SizeRuntime())
rc := s.ClientRuntime(cliFirstIP1)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliFirstHost1, host)
})
t.Run("update_storage", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
require.Equal(t, len(updatedClients), added)
require.Equal(t, len(firstClients), removed)
require.Equal(t, len(updatedClients), s.SizeRuntime())
rc := s.ClientRuntime(cliUpdatedIP3)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliUpdatedHost3, host)
rc = s.ClientRuntime(cliFirstIP1)
assert.Nil(t, rc)
})
t.Run("remove_all", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
require.Equal(t, 0, added)
require.Equal(t, len(updatedClients), removed)
require.Equal(t, 0, s.SizeRuntime())
})
}

View file

@ -47,9 +47,6 @@ type clientsContainer struct {
// storage stores information about persistent clients.
storage *client.Storage
// runtimeIndex stores information about runtime clients.
runtimeIndex *client.RuntimeIndex
// dhcp is the DHCP service implementation.
dhcp DHCP
@ -105,8 +102,6 @@ func (clients *clientsContainer) Init(
return errors.Error("clients container already initialized")
}
clients.runtimeIndex = client.NewRuntimeIndex()
clients.storage = client.NewStorage(&client.Config{
AllowedTags: clientTags,
})
@ -358,7 +353,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent
}
rc := clients.runtimeIndex.Client(ip)
rc := clients.storage.ClientRuntime(ip)
if rc != nil {
src, _ = rc.Info()
}
@ -539,22 +534,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
return clients.storage.FindByMAC(foundMAC)
}
// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
if ip == (netip.Addr{}) {
return nil
}
clients.lock.Lock()
defer clients.lock.Unlock()
return clients.runtimeIndex.Client(ip)
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc = clients.runtimeClient(ip)
rc = clients.storage.ClientRuntime(ip)
host := clients.dhcp.HostByIP(ip)
if host != "" {
@ -580,20 +562,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return
}
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
host, _ := rc.Info()
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
}
rc := client.NewRuntime(ip)
rc.SetWHOIS(wi)
clients.storage.UpdateRuntime(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// addHost adds a new IP-hostname pairing. The priorities of the sources are
@ -644,26 +617,20 @@ func (clients *clientsContainer) addHostLocked(
host string,
src client.Source,
) (ok bool) {
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
return false
}
}
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
rc := client.NewRuntime(ip)
rc.SetInfo(src, []string{host})
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
}
rc.SetInfo(src, []string{host})
clients.storage.UpdateRuntime(rc)
log.Debug(
"clients: adding client info %s -> %q %q [%d]",
ip,
src,
host,
clients.runtimeIndex.Size(),
clients.storage.SizeRuntime(),
)
return true
@ -675,23 +642,22 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
clients.lock.Lock()
defer clients.lock.Unlock()
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
added := 0
var rcs []*client.Runtime
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
added++
}
rc := client.NewRuntime(addr)
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
rcs = append(rcs, rc)
return true
})
log.Debug("clients: added %d client aliases from system hosts file", added)
added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@ -715,17 +681,16 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock()
defer clients.lock.Unlock()
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
added := 0
var rcs []*client.Runtime
for _, n := range ns {
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
added++
}
rc := client.NewRuntime(n.IP)
rc.SetInfo(client.SourceARP, []string{n.Name})
rcs = append(rcs, rc)
}
log.Debug("clients: added %d client aliases from arp neighborhood", added)
added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed)
}
// close gracefully closes all the client-specific upstream configurations of

View file

@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err)
clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip)
rc := clients.storage.ClientRuntime(ip)
require.Nil(t, rc)
assert.True(t, clients.storage.RemoveByName("client1"))

View file

@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true
})
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info()
cj := runtimeClientJSON{
WHOIS: whoisOrEmpty(rc),

View file

@ -354,6 +354,19 @@ type Info struct {
Orgname string `json:"orgname,omitempty"`
}
// Clone returns a deep copy of the WHOIS info.
func (i *Info) Clone() (c *Info) {
if i == nil {
return nil
}
return &Info{
City: i.City,
Country: i.Country,
Orgname: i.Orgname,
}
}
// cacheItem represents an item that we will store in the cache.
type cacheItem struct {
// expiry is the time when cacheItem will expire.