all: imp code

This commit is contained in:
Stanislav Chzhen 2024-09-09 21:13:05 +03:00
parent 2d2b8e0216
commit fafd7cbb52
9 changed files with 260 additions and 809 deletions

View file

@ -119,8 +119,8 @@ func (r *Runtime) Info() (cs Source, host string) {
return cs, info[0]
}
// SetInfo sets a host as a client information from the cs.
func (r *Runtime) SetInfo(cs Source, hosts []string) {
// setInfo sets a host as a client information from the cs.
func (r *Runtime) setInfo(cs Source, hosts []string) {
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
if len(hosts) == 1 && hosts[0] == "" {
hosts = []string{}
@ -138,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) {
}
}
// WHOIS returns a WHOIS client information.
// WHOIS returns a copy of WHOIS client information.
func (r *Runtime) WHOIS() (info *whois.Info) {
return r.whois
return r.whois.Clone()
}
// SetWHOIS sets a WHOIS client information. info must be non-nil.
func (r *Runtime) SetWHOIS(info *whois.Info) {
// setWHOIS sets a WHOIS client information. info must be non-nil.
func (r *Runtime) setWHOIS(info *whois.Info) {
r.whois = info
}
@ -178,8 +178,8 @@ func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}
// Clone returns a deep copy of the runtime client.
func (r *Runtime) Clone() (c *Runtime) {
// clone returns a deep copy of the runtime client.
func (r *Runtime) clone() (c *Runtime) {
return &Runtime{
ip: r.ip,
whois: r.whois.Clone(),

View file

@ -2,39 +2,34 @@ package client
import "net/netip"
// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// runtimeIndex stores information about runtime clients.
type runtimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}
// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
// newRuntimeIndex returns initialized runtime index.
func newRuntimeIndex() (ri *runtimeIndex) {
return &runtimeIndex{
index: map[netip.Addr]*Runtime{},
}
}
// Client returns the saved runtime client by ip. If no such client exists,
// client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip]
}
// Add saves the runtime client in the index. IP address of a client must be
// add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) {
func (ri *runtimeIndex) add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}
// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}
// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
// rangeF calls f for each runtime client in an undefined order.
func (ri *runtimeIndex) rangeF(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !f(rc) {
return
@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
}
}
// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
// setInfo sets the client information from cs for runtime client stored by ip.
// If no such client exists, it creates one.
func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) {
rc = ri.index[ip]
if rc == nil {
rc = NewRuntime(ip)
ri.add(rc)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
for ip, rc := range ri.index {
rc.unset(src)
rc.setInfo(cs, hosts)
return rc
}
// clearSource removes information from the specified source from all clients.
func (ri *runtimeIndex) clearSource(src Source) {
for _, rc := range ri.index {
rc.unset(src)
}
}
// removeEmpty removes empty runtime clients and returns the number of removed
// clients.
func (ri *runtimeIndex) removeEmpty() (n int) {
for ip, rc := range ri.index {
if rc.isEmpty() {
delete(ri.index, ip)
n++

View file

@ -1,85 +0,0 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)
func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP
var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)
ri := client.NewRuntimeIndex()
currentSize := 0
testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)
ri.Add(rc)
currentSize++
got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}
t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})
t.Run("range", func(t *testing.T) {
s := 0
ri.Range(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, currentSize, s)
})
t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--
assert.Equal(t, currentSize, ri.Size())
})
t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}

View file

@ -2,10 +2,10 @@ package client
import (
"cmp"
"context"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
@ -24,8 +24,8 @@ type DHCP interface {
Leases() (leases []*dhcpsvc.Lease)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have a hostname.
// address. host will be empty if there is no such client, due to an
// assumption that a DHCP client must always have a hostname.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
@ -34,31 +34,47 @@ type DHCP interface {
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
}
// emptyDHCP is the empty [DHCP] implementation that does nothing.
type emptyDHCP struct{}
// type check
var _ DHCP = emptyDHCP{}
func (emptyDHCP) Leases() (_ []*dhcpsvc.Lease) { return nil }
// Leases implements the [DHCP] interface for emptyDHCP.
func (emptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil }
func (emptyDHCP) HostByIP(_ netip.Addr) (_ string) { return "" }
// HostByIP implements the [DHCP] interface for emptyDHCP.
func (emptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" }
func (emptyDHCP) MACByIP(_ netip.Addr) (_ net.HardwareAddr) { return nil }
// MACByIP implements the [DHCP] interface for emptyDHCP.
func (emptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
// HostsContainer is an interface for receiving updates to the system hosts
// file.
type HostsContainer interface {
Upd() (updates <-chan *hostsfile.DefaultStorage)
}
// Config is the client storage configuration structure.
type Config struct {
// DHCP is used to update [SourceDHCP] runtime client information.
DHCP DHCP
// EtcHosts is used to update [SourceHostsFile] runtime client information.
EtcHosts HostsContainer
// ARPDB is used to update [SourceARP] runtime client information.
ARPDB arpdb.Interface
// AllowedTags is a list of all allowed client tags.
AllowedTags []string
// InitialClients is a list of persistent clients parsed from the
// configuration file. Each client must not be nil.
InitialClients []*Persistent
// ARPClientsUpdatePeriod defines how often [SourceARP] runtime client
// information is updated.
ARPClientsUpdatePeriod time.Duration
}
@ -74,11 +90,22 @@ type Storage struct {
index *index
// runtimeIndex contains information about runtime clients.
runtimeIndex *RuntimeIndex
runtimeIndex *runtimeIndex
// dhcp is used to update [SourceDHCP] runtime client information.
dhcp DHCP
// etcHosts is used to update [SourceHostsFile] runtime client information.
etcHosts HostsContainer
// arpDB is used to update [SourceARP] runtime client information.
arpDB arpdb.Interface
// done is the shutdown signaling channel.
done chan struct{}
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
// information is updated. It must be greater than zero.
arpClientsUpdatePeriod time.Duration
}
@ -89,11 +116,12 @@ func NewStorage(conf *Config) (s *Storage, err error) {
allowedTags: allowedTags,
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: NewRuntimeIndex(),
runtimeIndex: newRuntimeIndex(),
dhcp: cmp.Or(conf.DHCP, DHCP(emptyDHCP{})),
etcHosts: conf.EtcHosts,
arpDB: conf.ARPDB,
arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod,
done: make(chan struct{}),
}
for i, p := range conf.InitialClients {
@ -107,9 +135,18 @@ func NewStorage(conf *Config) (s *Storage, err error) {
}
// Start starts the goroutines for updating the runtime client information.
func (s *Storage) Start() {
func (s *Storage) Start(_ context.Context) (err error) {
go s.periodicARPUpdate()
go s.handleHostsUpdates()
return nil
}
// Shutdown gracefully stops the client storage.
func (s *Storage) Shutdown(_ context.Context) (err error) {
close(s.done)
return nil
}
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
@ -117,9 +154,15 @@ func (s *Storage) Start() {
func (s *Storage) periodicARPUpdate() {
defer log.OnPanic("storage")
t := time.NewTicker(s.arpClientsUpdatePeriod)
for {
select {
case <-t.C:
s.ReloadARP()
time.Sleep(s.arpClientsUpdatePeriod)
case <-s.done:
return
}
}
}
@ -133,6 +176,9 @@ func (s *Storage) ReloadARP() {
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (s *Storage) addFromSystemARP() {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.arpDB.Refresh(); err != nil {
s.arpDB = arpdb.Empty{}
log.Error("refreshing arp container: %s", err)
@ -147,16 +193,16 @@ func (s *Storage) addFromSystemARP() {
return
}
var rcs []*Runtime
for _, n := range ns {
rc := NewRuntime(n.IP)
rc.SetInfo(SourceARP, []string{n.Name})
src := SourceARP
s.runtimeIndex.clearSource(src)
rcs = append(rcs, rc)
for _, n := range ns {
s.runtimeIndex.setInfo(n.IP, src, []string{n.Name})
}
added, removed := s.BatchUpdateBySource(SourceARP, rcs)
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", added, removed)
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed)
}
// handleHostsUpdates receives the updates from the hosts container and adds
@ -168,29 +214,38 @@ func (s *Storage) handleHostsUpdates() {
defer log.OnPanic("storage")
for upd := range s.etcHosts.Upd() {
for {
select {
case upd := <-s.etcHosts.Upd():
s.addFromHostsFile(upd)
case <-s.done:
return
}
}
}
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
var rcs []*Runtime
s.mu.Lock()
defer s.mu.Unlock()
src := SourceHostsFile
s.runtimeIndex.clearSource(src)
added := 0
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
rc := NewRuntime(addr)
rc.SetInfo(SourceHostsFile, []string{names[0]})
rcs = append(rcs, rc)
s.runtimeIndex.setInfo(addr, src, []string{names[0]})
added++
return true
})
added, removed := s.BatchUpdateBySource(SourceHostsFile, rcs)
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed)
}
@ -204,10 +259,11 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if host != "" {
rc := NewRuntime(ip)
rc.SetInfo(SourceRDNS, []string{host})
s.UpdateRuntime(rc)
s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host})
}
if info != nil {
@ -215,18 +271,44 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
}
}
// UpdateDHCP updates [SourceDHCP] runtime client information.
func (s *Storage) UpdateDHCP() {
if s.dhcp == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
src := SourceDHCP
s.runtimeIndex.clearSource(src)
added := 0
for _, l := range s.dhcp.Leases() {
s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname})
added++
}
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed)
}
// setWHOISInfo sets the WHOIS information for a runtime client.
func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
_, ok := s.Find(ip.String())
_, ok := s.index.findByIP(ip)
if ok {
log.Debug("storage: client for %s is already created, ignore whois info", ip)
return
}
rc := NewRuntime(ip)
rc.SetWHOIS(wi)
s.UpdateRuntime(rc)
rc := s.runtimeIndex.client(ip)
if rc == nil {
rc = NewRuntime(ip)
s.runtimeIndex.add(rc)
}
rc.setWHOIS(wi)
log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi)
}
@ -422,9 +504,9 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
rc = s.runtimeIndex.Client(ip)
rc = s.runtimeIndex.client(ip)
if rc != nil {
return rc
return rc.clone()
}
host := s.dhcp.HostByIP(ip)
@ -432,87 +514,9 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
return nil
}
// TODO(s.chzhen): Update runtime index.
rc = NewRuntime(ip)
rc.SetInfo(SourceDHCP, []string{host})
rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host})
return rc
}
// UpdateRuntime updates the stored runtime client with information from rc. If
// no such client exists, saves the copy of rc in storage. rc must not be nil.
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateRuntimeLocked(rc)
}
// updateRuntimeLocked updates the stored runtime client with information from
// rc. rc must not be nil. Storage.mu is expected to be locked.
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
stored := s.runtimeIndex.Client(rc.ip)
if stored == nil {
s.runtimeIndex.Add(rc.Clone())
return true
}
if rc.whois != nil {
stored.whois = rc.whois.Clone()
}
if rc.arp != nil {
stored.arp = slices.Clone(rc.arp)
}
if rc.rdns != nil {
stored.rdns = slices.Clone(rc.rdns)
}
if rc.dhcp != nil {
stored.dhcp = slices.Clone(rc.dhcp)
}
if rc.hostsFile != nil {
stored.hostsFile = slices.Clone(rc.hostsFile)
}
return false
}
// BatchUpdateBySource updates the stored runtime clients information from the
// specified source and returns the number of added and removed clients.
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
s.mu.Lock()
defer s.mu.Unlock()
for _, rc := range s.runtimeIndex.index {
rc.unset(src)
}
for _, rc := range rcs {
if s.updateRuntimeLocked(rc) {
added++
}
}
for ip, rc := range s.runtimeIndex.index {
if rc.isEmpty() {
delete(s.runtimeIndex.index, ip)
removed++
}
}
return added, removed
}
// SizeRuntime returns the number of the runtime clients.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Size()
return rc.clone()
}
// RangeRuntime calls f for each runtime client in an undefined order.
@ -520,16 +524,5 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Range(f)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteBySource(src Source) (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.DeleteBySource(src)
s.runtimeIndex.rangeF(f)
}

View file

@ -19,6 +19,8 @@ import (
"github.com/stretchr/testify/require"
)
// testHostsContainer is a mock implementation of the [client.HostsContainer]
// interface.
type testHostsContainer struct {
onUpd func() (updates <-chan *hostsfile.DefaultStorage)
}
@ -26,6 +28,7 @@ type testHostsContainer struct {
// type check
var _ client.HostsContainer = (*testHostsContainer)(nil)
// Upd implements the [client.HostsContainer] interface for *testHostsContainer.
func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
return c.onUpd()
}
@ -41,33 +44,42 @@ type Interface interface {
Neighbors() (ns []arpdb.Neighbor)
}
// testARP is a mock implementation of the [arpdb.Interface].
type testARP struct {
onRefresh func() (err error)
onNeighbors func() (ns []arpdb.Neighbor)
}
// type check
var _ arpdb.Interface = (*testARP)(nil)
// Refresh implements the [arpdb.Interface] interface for *testARP.
func (c *testARP) Refresh() (err error) {
return c.onRefresh()
}
// Neighbors implements the [arpdb.Interface] interface for *testARP.
func (c *testARP) Neighbors() (ns []arpdb.Neighbor) {
return c.onNeighbors()
}
// testDHCP is a mock implementation of the [client.DHCP].
type testDHCP struct {
OnLeases func() (leases []*dhcpsvc.Lease)
OnHostBy func(ip netip.Addr) (host string)
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
}
// Lease implements the [DHCP] interface for testDHCP.
// type check
var _ client.DHCP = (*testDHCP)(nil)
// Lease implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
// HostByIP implements the [DHCP] interface for testDHCP.
// HostByIP implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
// MACByIP implements the [DHCP] interface for testDHCP.
// MACByIP implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
// compareRuntimeInfo is a helper function that returns true if the runtime
@ -99,10 +111,16 @@ func TestStorage_Add_hostsfile(t *testing.T) {
storage, err := client.NewStorage(&client.Config{
EtcHosts: h,
ARPClientsUpdatePeriod: testTimeout / 10,
})
require.NoError(t, err)
storage.Start()
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
t.Run("add_hosts", func(t *testing.T) {
var s *hostsfile.DefaultStorage
@ -184,7 +202,12 @@ func TestStorage_Add_arp(t *testing.T) {
})
require.NoError(t, err)
storage.Start()
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
t.Run("add_hosts", func(t *testing.T) {
func() {
@ -292,11 +315,19 @@ func TestStorage_Add_whois(t *testing.T) {
func TestClientsDHCP(t *testing.T) {
var (
cliIP1 = netip.MustParseAddr("1.1.1.1")
cliName1 = "client_one"
cliName1 = "one.dhcp"
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
cliName2 = "two.dhcp"
cliIP3 = netip.MustParseAddr("3.3.3.3")
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
cliName3 = "three.dhcp"
prsCliIP = netip.MustParseAddr("4.3.2.1")
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
prsCliName = "persitent_client"
prsCliName = "persitent.dhcp"
)
ipToHost := map[netip.Addr]string{
@ -306,8 +337,20 @@ func TestClientsDHCP(t *testing.T) {
prsCliIP: prsCliMAC,
}
leases := []*dhcpsvc.Lease{{
IP: cliIP2,
Hostname: cliName2,
HWAddr: cliMAC2,
}, {
IP: cliIP3,
Hostname: cliName3,
HWAddr: cliMAC3,
}}
d := &testDHCP{
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
OnLeases: func() (ls []*dhcpsvc.Lease) {
return leases
},
OnHostBy: func(ip netip.Addr) (host string) {
return ipToHost[ip]
},
@ -341,6 +384,34 @@ func TestClientsDHCP(t *testing.T) {
assert.Equal(t, prsCliName, prsCli.Name)
})
t.Run("leases", func(t *testing.T) {
delete(ipToHost, cliIP1)
storage.UpdateDHCP()
cli1 := storage.ClientRuntime(cliIP1)
require.Nil(t, cli1)
for i, l := range leases {
cli := storage.ClientRuntime(l.IP)
require.NotNil(t, cli)
src, host := cli.Info()
assert.Equal(t, client.SourceDHCP, src)
assert.Equal(t, leases[i].Hostname, host)
}
})
t.Run("range", func(t *testing.T) {
s := 0
storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, len(leases), s)
})
}
func TestClientsAddExisting(t *testing.T) {
@ -439,14 +510,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
return s
}
// newRuntimeClient is a helper function that returns a new runtime client.
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
rc = client.NewRuntime(ip)
rc.SetInfo(source, []string{host})
return rc
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
@ -1037,159 +1100,3 @@ func TestStorage_RangeByName(t *testing.T) {
})
}
}
func TestStorage_UpdateRuntime(t *testing.T) {
const (
addedARP = "added_arp"
addedSecondARP = "added_arp"
updatedARP = "updated_arp"
cliCity = "City"
cliCountry = "Country"
cliOrgname = "Orgname"
)
var (
ip = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
)
updated := client.NewRuntime(ip)
updated.SetInfo(client.SourceARP, []string{updatedARP})
info := &whois.Info{
City: cliCity,
Country: cliCountry,
Orgname: cliOrgname,
}
updated.SetWHOIS(info)
s, err := client.NewStorage(&client.Config{
AllowedTags: nil,
})
require.NoError(t, err)
t.Run("add_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip)
added.SetInfo(client.SourceARP, []string{addedARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedARP, host)
})
t.Run("add_second_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip2)
added.SetInfo(client.SourceARP, []string{addedSecondARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 2, s.SizeRuntime())
got := s.ClientRuntime(ip2)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedSecondARP, host)
})
t.Run("update_first_client", func(t *testing.T) {
require.False(t, s.UpdateRuntime(updated))
got := s.ClientRuntime(ip)
require.Equal(t, 2, s.SizeRuntime())
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, updatedARP, host)
})
t.Run("remove_arp_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceARP)
require.Equal(t, 1, n)
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, _ := got.Info()
assert.Equal(t, client.SourceWHOIS, source)
assert.Equal(t, info, got.WHOIS())
})
t.Run("remove_whois_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceWHOIS)
require.Equal(t, 1, n)
require.Equal(t, 0, s.SizeRuntime())
})
}
func TestStorage_BatchUpdateBySource(t *testing.T) {
const (
defSrc = client.SourceARP
cliFirstHost1 = "host1"
cliFirstHost2 = "host2"
cliUpdatedHost3 = "host3"
cliUpdatedHost4 = "host4"
cliUpdatedHost5 = "host5"
)
var (
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
)
firstClients := []*client.Runtime{
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
}
updatedClients := []*client.Runtime{
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
}
s, err := client.NewStorage(&client.Config{
AllowedTags: nil,
})
require.NoError(t, err)
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
require.Equal(t, len(firstClients), added)
require.Equal(t, 0, removed)
require.Equal(t, len(firstClients), s.SizeRuntime())
rc := s.ClientRuntime(cliFirstIP1)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliFirstHost1, host)
})
t.Run("update_storage", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
require.Equal(t, len(updatedClients), added)
require.Equal(t, len(firstClients), removed)
require.Equal(t, len(updatedClients), s.SizeRuntime())
rc := s.ClientRuntime(cliUpdatedIP3)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliUpdatedHost3, host)
rc = s.ClientRuntime(cliFirstIP1)
assert.Nil(t, rc)
})
t.Run("remove_all", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
require.Equal(t, 0, added)
require.Equal(t, len(updatedClients), removed)
require.Equal(t, 0, s.SizeRuntime())
})
}

View file

@ -1,6 +1,7 @@
package home
import (
"context"
"fmt"
"net/netip"
"slices"
@ -110,7 +111,7 @@ func (clients *clientsContainer) Init(
var webHandlersRegistered = false
// Start starts the clients container.
func (clients *clientsContainer) Start() {
func (clients *clientsContainer) Start(ctx context.Context) (err error) {
if clients.testing {
return
}
@ -120,7 +121,7 @@ func (clients *clientsContainer) Start() {
clients.registerWebHandlers()
}
clients.storage.Start()
return clients.storage.Start(ctx)
}
// clientObject is the YAML representation of a persistent client.

View file

@ -3,34 +3,14 @@ package home
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testDHCP struct {
OnLeases func() (leases []*dhcpsvc.Lease)
OnHostBy func(ip netip.Addr) (host string)
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
}
// Lease implements the [DHCP] interface for testDHCP.
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
// HostByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
// MACByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer(t *testing.T) (c *clientsContainer) {
@ -40,359 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
testing: true,
}
dhcp := &testDHCP{
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
OnHostBy: func(ip netip.Addr) (host string) { return "" },
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
}
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
require.NoError(t, c.Init(nil, nil, nil, nil, &filtering.Config{}))
return c
}
// addHost adds a new IP-hostname pairing.
func (clients *clientsContainer) addHost(
ip netip.Addr,
host string,
src client.Source,
) (ok bool) {
rc := client.NewRuntime(ip)
rc.SetInfo(src, []string{host})
clients.storage.UpdateRuntime(rc)
return true
}
// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
_, ok := clients.storage.Find(ip.String())
if ok {
return
}
rc := client.NewRuntime(ip)
rc.SetWHOIS(wi)
clients.storage.UpdateRuntime(rc)
}
// clientSource checks if client with this IP address already exists and returns
// the highest priority client source.
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
_, ok := clients.storage.Find(ip.String())
if ok {
return client.SourcePersistent
}
rc := clients.storage.ClientRuntime(ip)
if rc != nil {
src, _ = rc.Info()
}
return src
}
func TestClients(t *testing.T) {
clients := newClientsContainer(t)
t.Run("add_success", func(t *testing.T) {
var (
cliNone = "1.2.3.4"
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
cliIPv6 = netip.MustParseAddr("1:2:3::4")
)
c := &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli1IP, cliIPv6},
}
err := clients.storage.Add(c)
require.NoError(t, err)
c = &client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli2IP},
}
err = clients.storage.Add(c)
require.NoError(t, err)
c, ok := clients.storage.Find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.storage.Find("1:2:3::4")
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.storage.Find(cli2)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
_, ok = clients.storage.Find(cliNone)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
})
t.Run("add_fail_name", func(t *testing.T) {
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
})
require.Error(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
err := clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
})
require.Error(t, err)
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
})
assert.Error(t, err)
})
t.Run("update_success", func(t *testing.T) {
var (
cliOld = "1.1.1.1"
cliNew = "1.1.1.2"
cliNewIP = netip.MustParseAddr(cliNew)
)
prev, ok := clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
})
require.NoError(t, err)
_, ok = clients.storage.Find(cliOld)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err = clients.storage.Update("client1", &client.Persistent{
Name: "client1-renamed",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true,
})
require.NoError(t, err)
c, ok := clients.storage.Find(cliNew)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.storage.FindByName("client1")
require.False(t, ok)
assert.Nil(t, nilCli)
require.Len(t, c.IDs(), 1)
assert.Equal(t, cliNewIP, c.IPs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.storage.RemoveByName("client1-renamed")
require.True(t, ok)
_, ok = clients.storage.Find("1.1.1.2")
assert.False(t, ok)
})
t.Run("del_fail", func(t *testing.T) {
ok := clients.storage.RemoveByName("client3")
assert.False(t, ok)
})
t.Run("addhost_success", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host2", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.addHost(ip, "from_arp", client.SourceARP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
})
t.Run("addhost_priority", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host1", client.SourceRDNS)
assert.True(t, ok)
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
})
}
func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer(t)
whois := &whois.Info{
Country: "AU",
Orgname: "Example Org",
}
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", client.SourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
})
require.NoError(t, err)
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.Nil(t, rc)
assert.True(t, clients.storage.RemoveByName("client1"))
})
}
func TestClientsAddExisting(t *testing.T) {
clients := &clientsContainer{
testing: true,
}
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
Enabled: true,
DataDir: t.TempDir(),
Conf4: dhcpd.V4ServerConf{
Enabled: true,
GatewayIP: netip.MustParseAddr("1.2.3.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("1.2.3.2"),
RangeEnd: netip.MustParseAddr("1.2.3.10"),
},
}
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
require.NoError(t, clients.Init(nil, dhcpServer, nil, nil, &filtering.Config{}))
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
err = clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
})
require.NoError(t, err)
// Now add an auto-client with the same IP.
ok := clients.addHost(ip, "test", client.SourceRDNS)
assert.True(t, ok)
})
t.Run("complicated", func(t *testing.T) {
// TODO(a.garipov): Properly decouple the DHCP server from the client
// storage.
if runtime.GOOS == "windows" {
t.Skip("skipping dhcp test on windows")
}
ip := netip.MustParseAddr("1.2.3.4")
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: ip,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
err = clients.storage.Add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
// Add a new client with the IP from the first client's IP range.
err = clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
})
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)

View file

@ -103,6 +103,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true
})
clients.storage.UpdateDHCP()
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info()
cj := runtimeClientJSON{
@ -117,18 +119,6 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true
})
// TODO(s.chzhen): Remove.
for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{
Name: l.Hostname,
Source: client.SourceDHCP,
IP: l.IP,
WHOIS: &whois.Info{},
}
data.RuntimeClients = append(data.RuntimeClients, cj)
}
data.Tags = clientTags
aghhttp.WriteJSONResponseOK(w, r, data)

View file

@ -1,6 +1,7 @@
package home
import (
"context"
"fmt"
"log/slog"
"net"
@ -457,9 +458,12 @@ func startDNSServer() error {
Context.filters.EnableFilters(false)
Context.clients.Start()
err := Context.clients.Start(context.TODO())
if err != nil {
return fmt.Errorf("couldn't start clients container: %w", err)
}
err := Context.dnsServer.Start()
err = Context.dnsServer.Start()
if err != nil {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
}