Pull request 2273: AG-27492-client-storage-runtime-sources

Squashed commit of the following:

commit 3191224d6d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 26 18:20:04 2024 +0300

    client: imp tests

commit 6cc4ed53a2
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 26 18:04:36 2024 +0300

    client: imp code

commit 79272b299a
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 26 16:10:06 2024 +0300

    all: imp code

commit 0a001fffbe
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Sep 24 20:05:47 2024 +0300

    all: imp tests

commit 80f7e98d30
Merge: df7492e9d e338214ad
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Sep 24 19:10:13 2024 +0300

    Merge branch 'master' into AG-27492-client-storage-runtime-sources

commit df7492e9de
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Sep 24 19:06:37 2024 +0300

    all: imp code

commit 23896ae5a6
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 19 21:04:34 2024 +0300

    client: fix typo

commit ba0ba2478c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 19 21:02:13 2024 +0300

    all: imp code

commit f7315be742
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 12 14:35:38 2024 +0300

    home: imp code

commit f63d0e80fb
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 12 14:15:49 2024 +0300

    all: imp code

commit 9feda414b6
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Sep 10 17:53:42 2024 +0300

    all: imp code

commit fafd7cbb52
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Sep 9 21:13:05 2024 +0300

    all: imp code

commit 2d2b8e0216
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Sep 5 20:55:10 2024 +0300

    client: add tests

commit 4d394e6f21
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 29 20:40:38 2024 +0300

    all: client storage runtime sources
This commit is contained in:
Stanislav Chzhen 2024-09-30 14:17:42 +03:00
parent e338214ad5
commit d40de33316
14 changed files with 1008 additions and 1176 deletions

View file

@ -119,8 +119,8 @@ func (r *Runtime) Info() (cs Source, host string) {
return cs, info[0] return cs, info[0]
} }
// SetInfo sets a host as a client information from the cs. // setInfo sets a host as a client information from the cs.
func (r *Runtime) SetInfo(cs Source, hosts []string) { func (r *Runtime) setInfo(cs Source, hosts []string) {
// TODO(s.chzhen): Use contract where hosts must contain non-empty host. // TODO(s.chzhen): Use contract where hosts must contain non-empty host.
if len(hosts) == 1 && hosts[0] == "" { if len(hosts) == 1 && hosts[0] == "" {
hosts = []string{} hosts = []string{}
@ -138,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) {
} }
} }
// WHOIS returns a WHOIS client information. // WHOIS returns a copy of WHOIS client information.
func (r *Runtime) WHOIS() (info *whois.Info) { func (r *Runtime) WHOIS() (info *whois.Info) {
return r.whois return r.whois.Clone()
} }
// SetWHOIS sets a WHOIS client information. info must be non-nil. // setWHOIS sets a WHOIS client information. info must be non-nil.
func (r *Runtime) SetWHOIS(info *whois.Info) { func (r *Runtime) setWHOIS(info *whois.Info) {
r.whois = info r.whois = info
} }
@ -178,8 +178,8 @@ func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip return r.ip
} }
// Clone returns a deep copy of the runtime client. // clone returns a deep copy of the runtime client.
func (r *Runtime) Clone() (c *Runtime) { func (r *Runtime) clone() (c *Runtime) {
return &Runtime{ return &Runtime{
ip: r.ip, ip: r.ip,
whois: r.whois.Clone(), whois: r.whois.Clone(),

View file

@ -13,7 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -136,7 +135,8 @@ type Persistent struct {
} }
// validate returns an error if persistent client information contains errors. // validate returns an error if persistent client information contains errors.
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { // allTags must be sorted.
func (c *Persistent) validate(allTags []string) (err error) {
switch { switch {
case c.Name == "": case c.Name == "":
return errors.Error("empty name") return errors.Error("empty name")
@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
} }
for _, t := range c.Tags { for _, t := range c.Tags {
if !allTags.Has(t) { _, ok := slices.BinarySearch(allTags, t)
if !ok {
return fmt.Errorf("invalid tag: %q", t) return fmt.Errorf("invalid tag: %q", t)
} }
} }

View file

@ -1,11 +1,8 @@
package client package client
import ( import (
"net/netip"
"testing" "testing"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -125,69 +122,3 @@ func TestPersistent_EqualIDs(t *testing.T) {
}) })
} }
} }
func TestPersistent_Validate(t *testing.T) {
const (
allowedTag = "allowed_tag"
notAllowedTag = "not_allowed_tag"
)
allowedTags := container.NewMapSet(allowedTag)
testCases := []struct {
name string
cli *Persistent
wantErrMsg string
}{{
name: "success",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
},
wantErrMsg: "",
}, {
name: "empty_name",
cli: &Persistent{
Name: "",
},
wantErrMsg: "empty name",
}, {
name: "no_id",
cli: &Persistent{
Name: "no_id",
},
wantErrMsg: "id required",
}, {
name: "no_uid",
cli: &Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "uid required",
}, {
name: "not_allowed_tag",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
Tags: []string{
notAllowedTag,
},
},
wantErrMsg: `invalid tag: "` + notAllowedTag + `"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.validate(allowedTags)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View file

@ -2,39 +2,34 @@ package client
import "net/netip" import "net/netip"
// RuntimeIndex stores information about runtime clients. // runtimeIndex stores information about runtime clients.
type RuntimeIndex struct { type runtimeIndex struct {
// index maps IP address to runtime client. // index maps IP address to runtime client.
index map[netip.Addr]*Runtime index map[netip.Addr]*Runtime
} }
// NewRuntimeIndex returns initialized runtime index. // newRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) { func newRuntimeIndex() (ri *runtimeIndex) {
return &RuntimeIndex{ return &runtimeIndex{
index: map[netip.Addr]*Runtime{}, index: map[netip.Addr]*Runtime{},
} }
} }
// Client returns the saved runtime client by ip. If no such client exists, // client returns the saved runtime client by ip. If no such client exists,
// returns nil. // returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) { func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip] return ri.index[ip]
} }
// Add saves the runtime client in the index. IP address of a client must be // add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil. // unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) { func (ri *runtimeIndex) add(rc *Runtime) {
ip := rc.Addr() ip := rc.Addr()
ri.index[ip] = rc ri.index[ip] = rc
} }
// Size returns the number of the runtime clients. // rangeClients calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Size() (n int) { func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) {
return len(ri.index)
}
// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index { for _, rc := range ri.index {
if !f(rc) { if !f(rc) {
return return
@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
} }
} }
// Delete removes the runtime client by ip. // setInfo sets the client information from cs for runtime client stored by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) { // If no such client exists, it creates one.
delete(ri.index, ip) func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) {
rc = ri.index[ip]
if rc == nil {
rc = NewRuntime(ip)
ri.add(rc)
}
rc.setInfo(cs, hosts)
return rc
} }
// DeleteBySource removes all runtime clients that have information only from // clearSource removes information from the specified source from all clients.
// the specified source and returns the number of removed clients. func (ri *runtimeIndex) clearSource(src Source) {
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) { for _, rc := range ri.index {
for ip, rc := range ri.index {
rc.unset(src) rc.unset(src)
}
}
// removeEmpty removes empty runtime clients and returns the number of removed
// clients.
func (ri *runtimeIndex) removeEmpty() (n int) {
for ip, rc := range ri.index {
if rc.isEmpty() { if rc.isEmpty() {
delete(ri.index, ip) delete(ri.index, ip)
n++ n++

View file

@ -1,85 +0,0 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)
func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP
var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)
ri := client.NewRuntimeIndex()
currentSize := 0
testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)
ri.Add(rc)
currentSize++
got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}
t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})
t.Run("range", func(t *testing.T) {
s := 0
ri.Range(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, currentSize, s)
})
t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--
assert.Equal(t, currentSize, ri.Size())
})
t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}

View file

@ -1,30 +1,113 @@
package client package client
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
"sync" "sync"
"time"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// Config is the client storage configuration structure. // allowedTags is the list of available client tags.
// var allowedTags = []string{
// TODO(s.chzhen): Expand. "device_audio",
type Config struct { "device_camera",
// AllowedTags is a list of all allowed client tags. "device_gameconsole",
AllowedTags []string "device_laptop",
"device_nas", // Network-attached Storage
"device_other",
"device_pc",
"device_phone",
"device_printer",
"device_securityalarm",
"device_tablet",
"device_tv",
"os_android",
"os_ios",
"os_linux",
"os_macos",
"os_other",
"os_windows",
"user_admin",
"user_child",
"user_regular",
}
// DHCP is an interface for accessing DHCP lease data the [Storage] needs.
type DHCP interface {
// Leases returns all the DHCP leases.
Leases() (leases []*dhcpsvc.Lease)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. host will be empty if there is no such client, due to an
// assumption that a DHCP client must always have a hostname.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
}
// EmptyDHCP is the empty [DHCP] implementation that does nothing.
type EmptyDHCP struct{}
// type check
var _ DHCP = EmptyDHCP{}
// Leases implements the [DHCP] interface for emptyDHCP.
func (EmptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil }
// HostByIP implements the [DHCP] interface for emptyDHCP.
func (EmptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" }
// MACByIP implements the [DHCP] interface for emptyDHCP.
func (EmptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
// HostsContainer is an interface for receiving updates to the system hosts
// file.
type HostsContainer interface {
Upd() (updates <-chan *hostsfile.DefaultStorage)
}
// StorageConfig is the client storage configuration structure.
type StorageConfig struct {
// DHCP is used to match IPs against MACs of persistent clients and update
// [SourceDHCP] runtime client information. It must not be nil.
DHCP DHCP
// EtcHosts is used to update [SourceHostsFile] runtime client information.
EtcHosts HostsContainer
// ARPDB is used to update [SourceARP] runtime client information.
ARPDB arpdb.Interface
// InitialClients is a list of persistent clients parsed from the
// configuration file. Each client must not be nil.
InitialClients []*Persistent
// ARPClientsUpdatePeriod defines how often [SourceARP] runtime client
// information is updated.
ARPClientsUpdatePeriod time.Duration
// RuntimeSourceDHCP specifies whether to update [SourceDHCP] information
// of runtime clients.
RuntimeSourceDHCP bool
} }
// Storage contains information about persistent and runtime clients. // Storage contains information about persistent and runtime clients.
type Storage struct { type Storage struct {
// allowedTags is a set of all allowed tags.
allowedTags *container.MapSet[string]
// mu protects indexes of persistent and runtime clients. // mu protects indexes of persistent and runtime clients.
mu *sync.Mutex mu *sync.Mutex
@ -32,19 +115,250 @@ type Storage struct {
index *index index *index
// runtimeIndex contains information about runtime clients. // runtimeIndex contains information about runtime clients.
runtimeIndex *RuntimeIndex runtimeIndex *runtimeIndex
// dhcp is used to update [SourceDHCP] runtime client information.
dhcp DHCP
// etcHosts is used to update [SourceHostsFile] runtime client information.
etcHosts HostsContainer
// arpDB is used to update [SourceARP] runtime client information.
arpDB arpdb.Interface
// done is the shutdown signaling channel.
done chan struct{}
// allowedTags is a sorted list of all allowed tags. It must not be
// modified after initialization.
//
// TODO(s.chzhen): Use custom type.
allowedTags []string
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
// information is updated. It must be greater than zero.
arpClientsUpdatePeriod time.Duration
// runtimeSourceDHCP specifies whether to update [SourceDHCP] information
// of runtime clients.
runtimeSourceDHCP bool
} }
// NewStorage returns initialized client storage. conf must not be nil. // NewStorage returns initialized client storage. conf must not be nil.
func NewStorage(conf *Config) (s *Storage) { func NewStorage(conf *StorageConfig) (s *Storage, err error) {
allowedTags := container.NewMapSet(conf.AllowedTags...) tags := slices.Clone(allowedTags)
slices.Sort(tags)
return &Storage{ s = &Storage{
allowedTags: allowedTags, allowedTags: tags,
mu: &sync.Mutex{}, mu: &sync.Mutex{},
index: newIndex(), index: newIndex(),
runtimeIndex: NewRuntimeIndex(), runtimeIndex: newRuntimeIndex(),
dhcp: conf.DHCP,
etcHosts: conf.EtcHosts,
arpDB: conf.ARPDB,
done: make(chan struct{}),
arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod,
runtimeSourceDHCP: conf.RuntimeSourceDHCP,
} }
for i, p := range conf.InitialClients {
err = s.Add(p)
if err != nil {
return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err)
}
}
s.ReloadARP()
return s, nil
}
// Start starts the goroutines for updating the runtime client information.
//
// TODO(s.chzhen): Pass context.
func (s *Storage) Start(_ context.Context) (err error) {
go s.periodicARPUpdate()
go s.handleHostsUpdates()
return nil
}
// Shutdown gracefully stops the client storage.
//
// TODO(s.chzhen): Pass context.
func (s *Storage) Shutdown(_ context.Context) (err error) {
close(s.done)
return s.closeUpstreams()
}
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
// intended to be used as a goroutine.
func (s *Storage) periodicARPUpdate() {
defer log.OnPanic("storage")
t := time.NewTicker(s.arpClientsUpdatePeriod)
for {
select {
case <-t.C:
s.ReloadARP()
case <-s.done:
return
}
}
}
// ReloadARP reloads runtime clients from ARP, if configured.
func (s *Storage) ReloadARP() {
if s.arpDB != nil {
s.addFromSystemARP()
}
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (s *Storage) addFromSystemARP() {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.arpDB.Refresh(); err != nil {
s.arpDB = arpdb.Empty{}
log.Error("refreshing arp container: %s", err)
return
}
ns := s.arpDB.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
return
}
src := SourceARP
s.runtimeIndex.clearSource(src)
for _, n := range ns {
s.runtimeIndex.setInfo(n.IP, src, []string{n.Name})
}
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed)
}
// handleHostsUpdates receives the updates from the hosts container and adds
// them to the clients storage. It is intended to be used as a goroutine.
func (s *Storage) handleHostsUpdates() {
if s.etcHosts == nil {
return
}
defer log.OnPanic("storage")
for {
select {
case upd, ok := <-s.etcHosts.Upd():
if !ok {
return
}
s.addFromHostsFile(upd)
case <-s.done:
return
}
}
}
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
s.mu.Lock()
defer s.mu.Unlock()
src := SourceHostsFile
s.runtimeIndex.clearSource(src)
added := 0
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
s.runtimeIndex.setInfo(addr, src, []string{names[0]})
added++
return true
})
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed)
}
// type check
var _ AddressUpdater = (*Storage)(nil)
// UpdateAddress implements the [AddressUpdater] interface for *Storage
func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
// Common fast path optimization.
if host == "" && info == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if host != "" {
s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host})
}
if info != nil {
s.setWHOISInfo(ip, info)
}
}
// UpdateDHCP updates [SourceDHCP] runtime client information.
func (s *Storage) UpdateDHCP() {
if s.dhcp == nil || !s.runtimeSourceDHCP {
return
}
s.mu.Lock()
defer s.mu.Unlock()
src := SourceDHCP
s.runtimeIndex.clearSource(src)
added := 0
for _, l := range s.dhcp.Leases() {
s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname})
added++
}
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed)
}
// setWHOISInfo sets the WHOIS information for a runtime client.
func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
_, ok := s.index.findByIP(ip)
if ok {
log.Debug("storage: client for %s is already created, ignore whois info", ip)
return
}
rc := s.runtimeIndex.client(ip)
if rc == nil {
rc = NewRuntime(ip)
s.runtimeIndex.add(rc)
}
rc.setWHOIS(wi)
log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi)
} }
// Add stores persistent client information or returns an error. // Add stores persistent client information or returns an error.
@ -94,6 +408,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
// Find finds persistent client by string representation of the client ID, IP // Find finds persistent client by string representation of the client ID, IP
// address, or MAC. And returns its shallow copy. // address, or MAC. And returns its shallow copy.
//
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
// the parsed IP address, if any.
func (s *Storage) Find(id string) (p *Persistent, ok bool) { func (s *Storage) Find(id string) (p *Persistent, ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -103,6 +420,16 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
return p.ShallowClone(), ok return p.ShallowClone(), ok
} }
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
foundMAC := s.dhcp.MACByIP(ip)
if foundMAC != nil {
return s.FindByMAC(foundMAC)
}
return nil, false return nil, false
} }
@ -130,11 +457,9 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
return nil, false return nil, false
} }
// FindByMAC finds persistent client by MAC and returns its shallow copy. // FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
// is expected to be locked.
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) { func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.findByMAC(mac) p, ok = s.index.findByMAC(mac)
if ok { if ok {
return p.ShallowClone(), ok return p.ShallowClone(), ok
@ -216,8 +541,8 @@ func (s *Storage) Size() (n int) {
return s.index.size() return s.index.size()
} }
// CloseUpstreams closes upstream configurations of persistent clients. // closeUpstreams closes upstream configurations of persistent clients.
func (s *Storage) CloseUpstreams() (err error) { func (s *Storage) closeUpstreams() (err error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -226,89 +551,27 @@ func (s *Storage) CloseUpstreams() (err error) {
// ClientRuntime returns a copy of the saved runtime client by ip. If no such // ClientRuntime returns a copy of the saved runtime client by ip. If no such
// client exists, returns nil. // client exists, returns nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.runtimeIndex.Client(ip) rc = s.runtimeIndex.client(ip)
} if rc != nil {
return rc.clone()
// UpdateRuntime updates the stored runtime client with information from rc. If
// no such client exists, saves the copy of rc in storage. rc must not be nil.
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
s.mu.Lock()
defer s.mu.Unlock()
return s.updateRuntimeLocked(rc)
}
// updateRuntimeLocked updates the stored runtime client with information from
// rc. rc must not be nil. Storage.mu is expected to be locked.
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
stored := s.runtimeIndex.Client(rc.ip)
if stored == nil {
s.runtimeIndex.Add(rc.Clone())
return true
} }
if rc.whois != nil { if !s.runtimeSourceDHCP {
stored.whois = rc.whois.Clone() return nil
} }
if rc.arp != nil { host := s.dhcp.HostByIP(ip)
stored.arp = slices.Clone(rc.arp) if host == "" {
return nil
} }
if rc.rdns != nil { rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host})
stored.rdns = slices.Clone(rc.rdns)
}
if rc.dhcp != nil { return rc.clone()
stored.dhcp = slices.Clone(rc.dhcp)
}
if rc.hostsFile != nil {
stored.hostsFile = slices.Clone(rc.hostsFile)
}
return false
}
// BatchUpdateBySource updates the stored runtime clients information from the
// specified source and returns the number of added and removed clients.
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
s.mu.Lock()
defer s.mu.Unlock()
for _, rc := range s.runtimeIndex.index {
rc.unset(src)
}
for _, rc := range rcs {
if s.updateRuntimeLocked(rc) {
added++
}
}
for ip, rc := range s.runtimeIndex.index {
if rc.isEmpty() {
delete(s.runtimeIndex.index, ip)
removed++
}
}
return added, removed
}
// SizeRuntime returns the number of the runtime clients.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Size()
} }
// RangeRuntime calls f for each runtime client in an undefined order. // RangeRuntime calls f for each runtime client in an undefined order.
@ -316,16 +579,11 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.runtimeIndex.Range(f) s.runtimeIndex.rangeClients(f)
} }
// DeleteBySource removes all runtime clients that have information only from // AllowedTags returns the list of available client tags. tags must not be
// the specified source and returns the number of removed clients. // modified.
// func (s *Storage) AllowedTags() (tags []string) {
// TODO(s.chzhen): Use it. return s.allowedTags
func (s *Storage) DeleteBySource(src Source) (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.DeleteBySource(src)
} }

View file

@ -3,23 +3,513 @@ package client_test
import ( import (
"net" "net"
"net/netip" "net/netip"
"runtime"
"slices"
"sync"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testHostsContainer is a mock implementation of the [client.HostsContainer]
// interface.
type testHostsContainer struct {
onUpd func() (updates <-chan *hostsfile.DefaultStorage)
}
// type check
var _ client.HostsContainer = (*testHostsContainer)(nil)
// Upd implements the [client.HostsContainer] interface for *testHostsContainer.
func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
return c.onUpd()
}
// Interface stores and refreshes the network neighborhood reported by ARP
// (Address Resolution Protocol).
type Interface interface {
// Refresh updates the stored data. It must be safe for concurrent use.
Refresh() (err error)
// Neighbors returnes the last set of data reported by ARP. Both the method
// and it's result must be safe for concurrent use.
Neighbors() (ns []arpdb.Neighbor)
}
// testARPDB is a mock implementation of the [arpdb.Interface].
type testARPDB struct {
onRefresh func() (err error)
onNeighbors func() (ns []arpdb.Neighbor)
}
// type check
var _ arpdb.Interface = (*testARPDB)(nil)
// Refresh implements the [arpdb.Interface] interface for *testARP.
func (c *testARPDB) Refresh() (err error) {
return c.onRefresh()
}
// Neighbors implements the [arpdb.Interface] interface for *testARP.
func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) {
return c.onNeighbors()
}
// testDHCP is a mock implementation of the [client.DHCP].
type testDHCP struct {
OnLeases func() (leases []*dhcpsvc.Lease)
OnHostBy func(ip netip.Addr) (host string)
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
}
// type check
var _ client.DHCP = (*testDHCP)(nil)
// Lease implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
// HostByIP implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
// MACByIP implements the [client.DHCP] interface for *testDHCP.
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
// compareRuntimeInfo is a helper function that returns true if the runtime
// client has provided info.
func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) {
s, h := rc.Info()
if s != src {
return false
} else if h != host {
return false
}
return true
}
func TestStorage_Add_hostsfile(t *testing.T) {
var (
cliIP1 = netip.MustParseAddr("1.1.1.1")
cliName1 = "client_one"
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliName2 = "client_two"
)
hostCh := make(chan *hostsfile.DefaultStorage)
h := &testHostsContainer{
onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh },
}
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
EtcHosts: h,
ARPClientsUpdatePeriod: testTimeout / 10,
})
require.NoError(t, err)
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
t.Run("add_hosts", func(t *testing.T) {
var s *hostsfile.DefaultStorage
s, err = hostsfile.NewDefaultStorage()
require.NoError(t, err)
s.Add(&hostsfile.Record{
Addr: cliIP1,
Names: []string{cliName1},
})
testutil.RequireSend(t, hostCh, s, testTimeout)
require.Eventually(t, func() (ok bool) {
cli1 := storage.ClientRuntime(cliIP1)
if cli1 == nil {
return false
}
assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1))
return true
}, testTimeout, testTimeout/10)
})
t.Run("update_hosts", func(t *testing.T) {
var s *hostsfile.DefaultStorage
s, err = hostsfile.NewDefaultStorage()
require.NoError(t, err)
s.Add(&hostsfile.Record{
Addr: cliIP2,
Names: []string{cliName2},
})
testutil.RequireSend(t, hostCh, s, testTimeout)
require.Eventually(t, func() (ok bool) {
cli2 := storage.ClientRuntime(cliIP2)
if cli2 == nil {
return false
}
assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2))
cli1 := storage.ClientRuntime(cliIP1)
require.Nil(t, cli1)
return true
}, testTimeout, testTimeout/10)
})
}
func TestStorage_Add_arp(t *testing.T) {
var (
mu sync.Mutex
neighbors []arpdb.Neighbor
cliIP1 = netip.MustParseAddr("1.1.1.1")
cliName1 = "client_one"
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliName2 = "client_two"
)
a := &testARPDB{
onRefresh: func() (err error) { return nil },
onNeighbors: func() (ns []arpdb.Neighbor) {
mu.Lock()
defer mu.Unlock()
return neighbors
},
}
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
ARPDB: a,
ARPClientsUpdatePeriod: testTimeout / 10,
})
require.NoError(t, err)
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
t.Run("add_hosts", func(t *testing.T) {
func() {
mu.Lock()
defer mu.Unlock()
neighbors = []arpdb.Neighbor{{
Name: cliName1,
IP: cliIP1,
}}
}()
require.Eventually(t, func() (ok bool) {
cli1 := storage.ClientRuntime(cliIP1)
if cli1 == nil {
return false
}
assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1))
return true
}, testTimeout, testTimeout/10)
})
t.Run("update_hosts", func(t *testing.T) {
func() {
mu.Lock()
defer mu.Unlock()
neighbors = []arpdb.Neighbor{{
Name: cliName2,
IP: cliIP2,
}}
}()
require.Eventually(t, func() (ok bool) {
cli2 := storage.ClientRuntime(cliIP2)
if cli2 == nil {
return false
}
assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2))
cli1 := storage.ClientRuntime(cliIP1)
require.Nil(t, cli1)
return true
}, testTimeout, testTimeout/10)
})
}
func TestStorage_Add_whois(t *testing.T) {
var (
cliIP1 = netip.MustParseAddr("1.1.1.1")
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliName2 = "client_two"
cliIP3 = netip.MustParseAddr("3.3.3.3")
cliName3 = "client_three"
)
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
})
require.NoError(t, err)
whois := &whois.Info{
Country: "AU",
Orgname: "Example Org",
}
t.Run("new_client", func(t *testing.T) {
storage.UpdateAddress(cliIP1, "", whois)
cli1 := storage.ClientRuntime(cliIP1)
require.NotNil(t, cli1)
assert.Equal(t, whois, cli1.WHOIS())
})
t.Run("existing_runtime_client", func(t *testing.T) {
storage.UpdateAddress(cliIP2, cliName2, nil)
storage.UpdateAddress(cliIP2, "", whois)
cli2 := storage.ClientRuntime(cliIP2)
require.NotNil(t, cli2)
assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2))
assert.Equal(t, whois, cli2.WHOIS())
})
t.Run("can't_set_persistent_client", func(t *testing.T) {
err = storage.Add(&client.Persistent{
Name: cliName3,
UID: client.MustNewUID(),
IPs: []netip.Addr{cliIP3},
})
require.NoError(t, err)
storage.UpdateAddress(cliIP3, "", whois)
rc := storage.ClientRuntime(cliIP3)
require.Nil(t, rc)
})
}
func TestClientsDHCP(t *testing.T) {
var (
cliIP1 = netip.MustParseAddr("1.1.1.1")
cliName1 = "one.dhcp"
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
cliName2 = "two.dhcp"
cliIP3 = netip.MustParseAddr("3.3.3.3")
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
cliName3 = "three.dhcp"
prsCliIP = netip.MustParseAddr("4.3.2.1")
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
prsCliName = "persistent.dhcp"
)
ipToHost := map[netip.Addr]string{
cliIP1: cliName1,
}
ipToMAC := map[netip.Addr]net.HardwareAddr{
prsCliIP: prsCliMAC,
}
leases := []*dhcpsvc.Lease{{
IP: cliIP2,
Hostname: cliName2,
HWAddr: cliMAC2,
}, {
IP: cliIP3,
Hostname: cliName3,
HWAddr: cliMAC3,
}}
d := &testDHCP{
OnLeases: func() (ls []*dhcpsvc.Lease) {
return leases
},
OnHostBy: func(ip netip.Addr) (host string) {
return ipToHost[ip]
},
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) {
return ipToMAC[ip]
},
}
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: d,
RuntimeSourceDHCP: true,
})
require.NoError(t, err)
t.Run("find_runtime", func(t *testing.T) {
cli1 := storage.ClientRuntime(cliIP1)
require.NotNil(t, cli1)
assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1))
})
t.Run("find_persistent", func(t *testing.T) {
err = storage.Add(&client.Persistent{
Name: prsCliName,
UID: client.MustNewUID(),
MACs: []net.HardwareAddr{prsCliMAC},
})
require.NoError(t, err)
prsCli, ok := storage.Find(prsCliIP.String())
require.True(t, ok)
assert.Equal(t, prsCliName, prsCli.Name)
})
t.Run("leases", func(t *testing.T) {
delete(ipToHost, cliIP1)
storage.UpdateDHCP()
cli1 := storage.ClientRuntime(cliIP1)
require.Nil(t, cli1)
for i, l := range leases {
cli := storage.ClientRuntime(l.IP)
require.NotNil(t, cli)
src, host := cli.Info()
assert.Equal(t, client.SourceDHCP, src)
assert.Equal(t, leases[i].Hostname, host)
}
})
t.Run("range", func(t *testing.T) {
s := 0
storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, len(leases), s)
})
}
func TestClientsAddExisting(t *testing.T) {
t.Run("simple", func(t *testing.T) {
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
})
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
err = storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
})
require.NoError(t, err)
// Now add an auto-client with the same IP.
storage.UpdateAddress(ip, "test", nil)
rc := storage.ClientRuntime(ip)
assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test"))
})
t.Run("complicated", func(t *testing.T) {
// TODO(a.garipov): Properly decouple the DHCP server from the client
// storage.
if runtime.GOOS == "windows" {
t.Skip("skipping dhcp test on windows")
}
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
Enabled: true,
DataDir: t.TempDir(),
Conf4: dhcpd.V4ServerConf{
Enabled: true,
GatewayIP: netip.MustParseAddr("1.2.3.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("1.2.3.2"),
RangeEnd: netip.MustParseAddr("1.2.3.10"),
},
}
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: dhcpServer,
})
require.NoError(t, err)
ip := netip.MustParseAddr("1.2.3.4")
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: ip,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
err = storage.Add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
// Add a new client with the IP from the first client's IP range.
err = storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
})
}
// newStorage is a helper function that returns a client storage filled with // newStorage is a helper function that returns a client storage filled with
// persistent clients from the m. It also generates a UID for each client. // persistent clients from the m. It also generates a UID for each client.
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
tb.Helper() tb.Helper()
s = client.NewStorage(&client.Config{ s, err := client.NewStorage(&client.StorageConfig{
AllowedTags: nil, DHCP: client.EmptyDHCP{},
}) })
require.NoError(tb, err)
for _, c := range m { for _, c := range m {
c.UID = client.MustNewUID() c.UID = client.MustNewUID()
@ -31,14 +521,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
return s return s
} }
// newRuntimeClient is a helper function that returns a new runtime client.
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
rc = client.NewRuntime(ip)
rc.SetInfo(source, []string{host})
return rc
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an // mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error. // error.
func mustParseMAC(s string) (mac net.HardwareAddr) { func mustParseMAC(s string) (mac net.HardwareAddr) {
@ -55,7 +537,7 @@ func TestStorage_Add(t *testing.T) {
existingName = "existing_name" existingName = "existing_name"
existingClientID = "existing_client_id" existingClientID = "existing_client_id"
allowedTag = "tag" allowedTag = "user_admin"
notAllowedTag = "not_allowed_tag" notAllowedTag = "not_allowed_tag"
) )
@ -73,10 +555,20 @@ func TestStorage_Add(t *testing.T) {
UID: existingClientUID, UID: existingClientUID,
} }
s := client.NewStorage(&client.Config{ s, err := client.NewStorage(&client.StorageConfig{})
AllowedTags: []string{allowedTag}, require.NoError(t, err)
})
err := s.Add(existingClient) tags := s.AllowedTags()
require.NotZero(t, len(tags))
require.True(t, slices.IsSorted(tags))
_, ok := slices.BinarySearch(tags, allowedTag)
require.True(t, ok)
_, ok = slices.BinarySearch(tags, notAllowedTag)
require.False(t, ok)
err = s.Add(existingClient)
require.NoError(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
@ -136,12 +628,43 @@ func TestStorage_Add(t *testing.T) {
}, { }, {
name: "not_allowed_tag", name: "not_allowed_tag",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "nont_allowed_tag", Name: "not_allowed_tag",
Tags: []string{notAllowedTag}, Tags: []string{notAllowedTag},
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
UID: client.MustNewUID(), UID: client.MustNewUID(),
}, },
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
}, {
name: "allowed_tag",
cli: &client.Persistent{
Name: "allowed_tag",
Tags: []string{allowedTag},
IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "",
cli: &client.Persistent{
Name: "",
IPs: []netip.Addr{netip.MustParseAddr("6.6.6.6")},
UID: client.MustNewUID(),
},
wantErrMsg: "adding client: empty name",
}, {
name: "no_id",
cli: &client.Persistent{
Name: "no_id",
UID: client.MustNewUID(),
},
wantErrMsg: "adding client: id required",
}, {
name: "no_uid",
cli: &client.Persistent{
Name: "no_uid",
IPs: []netip.Addr{netip.MustParseAddr("7.7.7.7")},
},
wantErrMsg: "adding client: uid required",
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -164,10 +687,10 @@ func TestStorage_RemoveByName(t *testing.T) {
UID: client.MustNewUID(), UID: client.MustNewUID(),
} }
s := client.NewStorage(&client.Config{ s, err := client.NewStorage(&client.StorageConfig{})
AllowedTags: nil, require.NoError(t, err)
})
err := s.Add(existingClient) err = s.Add(existingClient)
require.NoError(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
@ -191,9 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) {
} }
t.Run("duplicate_remove", func(t *testing.T) { t.Run("duplicate_remove", func(t *testing.T) {
s = client.NewStorage(&client.Config{ s, err = client.NewStorage(&client.StorageConfig{})
AllowedTags: nil, require.NoError(t, err)
})
err = s.Add(existingClient) err = s.Add(existingClient)
require.NoError(t, err) require.NoError(t, err)
@ -623,157 +1146,3 @@ func TestStorage_RangeByName(t *testing.T) {
}) })
} }
} }
func TestStorage_UpdateRuntime(t *testing.T) {
const (
addedARP = "added_arp"
addedSecondARP = "added_arp"
updatedARP = "updated_arp"
cliCity = "City"
cliCountry = "Country"
cliOrgname = "Orgname"
)
var (
ip = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
)
updated := client.NewRuntime(ip)
updated.SetInfo(client.SourceARP, []string{updatedARP})
info := &whois.Info{
City: cliCity,
Country: cliCountry,
Orgname: cliOrgname,
}
updated.SetWHOIS(info)
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
t.Run("add_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip)
added.SetInfo(client.SourceARP, []string{addedARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedARP, host)
})
t.Run("add_second_arp_client", func(t *testing.T) {
added := client.NewRuntime(ip2)
added.SetInfo(client.SourceARP, []string{addedSecondARP})
require.True(t, s.UpdateRuntime(added))
require.Equal(t, 2, s.SizeRuntime())
got := s.ClientRuntime(ip2)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedSecondARP, host)
})
t.Run("update_first_client", func(t *testing.T) {
require.False(t, s.UpdateRuntime(updated))
got := s.ClientRuntime(ip)
require.Equal(t, 2, s.SizeRuntime())
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, updatedARP, host)
})
t.Run("remove_arp_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceARP)
require.Equal(t, 1, n)
require.Equal(t, 1, s.SizeRuntime())
got := s.ClientRuntime(ip)
source, _ := got.Info()
assert.Equal(t, client.SourceWHOIS, source)
assert.Equal(t, info, got.WHOIS())
})
t.Run("remove_whois_info", func(t *testing.T) {
n := s.DeleteBySource(client.SourceWHOIS)
require.Equal(t, 1, n)
require.Equal(t, 0, s.SizeRuntime())
})
}
func TestStorage_BatchUpdateBySource(t *testing.T) {
const (
defSrc = client.SourceARP
cliFirstHost1 = "host1"
cliFirstHost2 = "host2"
cliUpdatedHost3 = "host3"
cliUpdatedHost4 = "host4"
cliUpdatedHost5 = "host5"
)
var (
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
)
firstClients := []*client.Runtime{
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
}
updatedClients := []*client.Runtime{
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
require.Equal(t, len(firstClients), added)
require.Equal(t, 0, removed)
require.Equal(t, len(firstClients), s.SizeRuntime())
rc := s.ClientRuntime(cliFirstIP1)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliFirstHost1, host)
})
t.Run("update_storage", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
require.Equal(t, len(updatedClients), added)
require.Equal(t, len(firstClients), removed)
require.Equal(t, len(updatedClients), s.SizeRuntime())
rc := s.ClientRuntime(cliUpdatedIP3)
src, host := rc.Info()
assert.Equal(t, defSrc, src)
assert.Equal(t, cliUpdatedHost3, host)
rc = s.ClientRuntime(cliFirstIP1)
assert.Nil(t, rc)
})
t.Run("remove_all", func(t *testing.T) {
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
require.Equal(t, 0, added)
require.Equal(t, len(updatedClients), removed)
require.Equal(t, 0, s.SizeRuntime())
})
}

View file

@ -1,8 +1,8 @@
package home package home
import ( import (
"context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"slices" "slices"
"sync" "sync"
@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
@ -20,47 +19,18 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
) )
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
// needs.
type DHCP interface {
// Leases returns all the DHCP leases.
Leases() (leases []*dhcpsvc.Lease)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have a hostname.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
}
// clientsContainer is the storage of all runtime and persistent clients. // clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct { type clientsContainer struct {
// storage stores information about persistent clients. // storage stores information about persistent clients.
storage *client.Storage storage *client.Storage
// dhcp is the DHCP service implementation.
dhcp DHCP
// clientChecker checks if a client is blocked by the current access // clientChecker checks if a client is blocked by the current access
// settings. // settings.
clientChecker BlockedClientChecker clientChecker BlockedClientChecker
// etcHosts contains list of rewrite rules taken from the operating system's
// hosts database.
etcHosts *aghnet.HostsContainer
// arpDB stores the neighbors retrieved from ARP.
arpDB arpdb.Interface
// lock protects all fields. // lock protects all fields.
// //
// TODO(a.garipov): Use a pointer and describe which fields are protected in // TODO(a.garipov): Use a pointer and describe which fields are protected in
@ -92,7 +62,7 @@ type BlockedClientChecker interface {
// Note: this function must be called only once // Note: this function must be called only once
func (clients *clientsContainer) Init( func (clients *clientsContainer) Init(
objects []*clientObject, objects []*clientObject,
dhcpServer DHCP, dhcpServer client.DHCP,
etcHosts *aghnet.HostsContainer, etcHosts *aghnet.HostsContainer,
arpDB arpdb.Interface, arpDB arpdb.Interface,
filteringConf *filtering.Config, filteringConf *filtering.Config,
@ -102,26 +72,15 @@ func (clients *clientsContainer) Init(
return errors.Error("clients container already initialized") return errors.Error("clients container already initialized")
} }
clients.storage = client.NewStorage(&client.Config{ confClients := make([]*client.Persistent, 0, len(objects))
AllowedTags: clientTags, for i, o := range objects {
}) var p *client.Persistent
p, err = o.toPersistent(filteringConf)
if err != nil {
return fmt.Errorf("init persistent client at index %d: %w", i, err)
}
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. confClients = append(confClients, p)
clients.dhcp = dhcpServer
clients.etcHosts = etcHosts
clients.arpDB = arpDB
err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing {
return nil
} }
// The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile // The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile
@ -130,21 +89,26 @@ func (clients *clientsContainer) Init(
// TODO(e.burkov): The option should probably be returned, since hosts file // TODO(e.burkov): The option should probably be returned, since hosts file
// currently used not only for clients' information enrichment, but also in // currently used not only for clients' information enrichment, but also in
// the filtering module and upstream addresses resolution. // the filtering module and upstream addresses resolution.
if config.Clients.Sources.HostsFile && clients.etcHosts != nil { var hosts client.HostsContainer = etcHosts
go clients.handleHostsUpdates() if !config.Clients.Sources.HostsFile {
hosts = nil
}
clients.storage, err = client.NewStorage(&client.StorageConfig{
InitialClients: confClients,
DHCP: dhcpServer,
EtcHosts: hosts,
ARPDB: arpDB,
ARPClientsUpdatePeriod: arpClientsUpdatePeriod,
RuntimeSourceDHCP: config.Clients.Sources.DHCP,
})
if err != nil {
return fmt.Errorf("init client storage: %w", err)
} }
return nil return nil
} }
// handleHostsUpdates receives the updates from the hosts container and adds
// them to the clients container. It is intended to be used as a goroutine.
func (clients *clientsContainer) handleHostsUpdates() {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
}
// webHandlersRegistered prevents a [clientsContainer] from registering its web // webHandlersRegistered prevents a [clientsContainer] from registering its web
// handlers more than once. // handlers more than once.
// //
@ -152,7 +116,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
var webHandlersRegistered = false var webHandlersRegistered = false
// Start starts the clients container. // Start starts the clients container.
func (clients *clientsContainer) Start() { func (clients *clientsContainer) Start(ctx context.Context) (err error) {
if clients.testing { if clients.testing {
return return
} }
@ -162,14 +126,7 @@ func (clients *clientsContainer) Start() {
clients.registerWebHandlers() clients.registerWebHandlers()
} }
go clients.periodicUpdate() return clients.storage.Start(ctx)
}
// reloadARP reloads runtime clients from ARP, if configured.
func (clients *clientsContainer) reloadARP() {
if clients.arpDB != nil {
clients.addFromSystemARP()
}
} }
// clientObject is the YAML representation of a persistent client. // clientObject is the YAML representation of a persistent client.
@ -270,28 +227,6 @@ func (o *clientObject) toPersistent(
return cli, nil return cli, nil
} }
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for i, o := range objects {
var cli *client.Persistent
cli, err = o.toPersistent(filteringConf)
if err != nil {
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
}
err = clients.storage.Add(cli)
if err != nil {
return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err)
}
}
return nil
}
// forConfig returns all currently known persistent clients as objects for the // forConfig returns all currently known persistent clients as objects for the
// configuration file. // configuration file.
func (clients *clientsContainer) forConfig() (objs []*clientObject) { func (clients *clientsContainer) forConfig() (objs []*clientObject) {
@ -332,39 +267,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
// arpClientsUpdatePeriod defines how often ARP clients are updated. // arpClientsUpdatePeriod defines how often ARP clients are updated.
const arpClientsUpdatePeriod = 10 * time.Minute const arpClientsUpdatePeriod = 10 * time.Minute
func (clients *clientsContainer) periodicUpdate() {
defer log.OnPanic("clients container")
for {
clients.reloadARP()
time.Sleep(arpClientsUpdatePeriod)
}
}
// clientSource checks if client with this IP address already exists and returns
// the source which updated it last. It returns [client.SourceNone] if the
// client doesn't exist. Note that it is only used in tests.
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip.String())
if ok {
return client.SourcePersistent
}
rc := clients.storage.ClientRuntime(ip)
if rc != nil {
src, _ = rc.Info()
}
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
src = client.SourceDHCP
}
return src
}
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid // findMultiple is a wrapper around [clientsContainer.find] to make it a valid
// client finder for the query log. c is never nil; if no information about the // client finder for the query log. c is never nil; if no information about the
// client is found, it returns an artificial client record by only setting the // client is found, it returns an artificial client record by only setting the
@ -410,7 +312,7 @@ func (clients *clientsContainer) clientOrArtificial(
}, false }, false
} }
rc := clients.findRuntimeClient(ip) rc := clients.storage.ClientRuntime(ip)
if rc != nil { if rc != nil {
_, host := rc.Info() _, host := rc.Info()
@ -425,19 +327,6 @@ func (clients *clientsContainer) clientOrArtificial(
}, true }, true
} }
// find returns a shallow copy of the client if there is one found.
func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok = clients.findLocked(id)
if !ok {
return nil, false
}
return c, true
}
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a // shouldCountClient is a wrapper around [clientsContainer.find] to make it a
// valid client information finder for the statistics. If no information about // valid client information finder for the statistics. If no information about
// the client is found, it returns true. // the client is found, it returns true.
@ -446,7 +335,7 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
defer clients.lock.Unlock() defer clients.lock.Unlock()
for _, id := range ids { for _, id := range ids {
client, ok := clients.findLocked(id) client, ok := clients.storage.Find(id)
if ok { if ok {
return !client.IgnoreStatistics return !client.IgnoreStatistics
} }
@ -468,7 +357,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
c, ok := clients.findLocked(id) c, ok := clients.storage.Find(id)
if !ok { if !ok {
return nil, nil return nil, nil
} else if c.UpstreamConfig != nil { } else if c.UpstreamConfig != nil {
@ -506,198 +395,17 @@ func (clients *clientsContainer) UpstreamConfigByID(
return conf, nil return conf, nil
} }
// findLocked searches for a client by its ID. clients.lock is expected to be
// locked.
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
c, ok = clients.storage.Find(id)
if ok {
return c, true
}
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
// TODO(e.burkov): Iterate through clients.list only once.
return clients.findDHCP(ip)
}
// findDHCP searches for a client by its MAC, if the DHCP server is active and
// there is such client. clients.lock is expected to be locked.
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) {
foundMAC := clients.dhcp.MACByIP(ip)
if foundMAC == nil {
return nil, false
}
return clients.storage.FindByMAC(foundMAC)
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc = clients.storage.ClientRuntime(ip)
host := clients.dhcp.HostByIP(ip)
if host != "" {
if rc == nil {
rc = client.NewRuntime(ip)
}
rc.SetInfo(client.SourceDHCP, []string{host})
return rc
}
return rc
}
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
// expected to be locked.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
_, ok := clients.findLocked(ip.String())
if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
rc := client.NewRuntime(ip)
rc.SetWHOIS(wi)
clients.storage.UpdateRuntime(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// addHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
//
// TODO(a.garipov): Only used in internal tests. Consider removing.
func (clients *clientsContainer) addHost(
ip netip.Addr,
host string,
src client.Source,
) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
return clients.addHostLocked(ip, host, src)
}
// type check // type check
var _ client.AddressUpdater = (*clientsContainer)(nil) var _ client.AddressUpdater = (*clientsContainer)(nil)
// UpdateAddress implements the [client.AddressUpdater] interface for // UpdateAddress implements the [client.AddressUpdater] interface for
// *clientsContainer // *clientsContainer
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
// Common fast path optimization. clients.storage.UpdateAddress(ip, host, info)
if host == "" && info == nil {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
if host != "" {
ok := clients.addHostLocked(ip, host, client.SourceRDNS)
if !ok {
log.Debug("clients: host for client %q already set with higher priority source", ip)
}
}
if info != nil {
clients.setWHOISInfo(ip, info)
}
}
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
// locked.
func (clients *clientsContainer) addHostLocked(
ip netip.Addr,
host string,
src client.Source,
) (ok bool) {
rc := client.NewRuntime(ip)
rc.SetInfo(src, []string{host})
if config.Clients.Sources.DHCP {
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
}
}
clients.storage.UpdateRuntime(rc)
log.Debug(
"clients: adding client info %s -> %q %q [%d]",
ip,
src,
host,
clients.storage.SizeRuntime(),
)
return true
}
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
clients.lock.Lock()
defer clients.lock.Unlock()
var rcs []*client.Runtime
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
rc := client.NewRuntime(addr)
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
rcs = append(rcs, rc)
return true
})
added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if err := clients.arpDB.Refresh(); err != nil {
log.Error("refreshing arp container: %s", err)
clients.arpDB = arpdb.Empty{}
return
}
ns := clients.arpDB.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
var rcs []*client.Runtime
for _, n := range ns {
rc := client.NewRuntime(n.IP)
rc.SetInfo(client.SourceARP, []string{n.Name})
rcs = append(rcs, rc)
}
added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed)
} }
// close gracefully closes all the client-specific upstream configurations of // close gracefully closes all the client-specific upstream configurations of
// the persistent clients. // the persistent clients.
func (clients *clientsContainer) close() (err error) { func (clients *clientsContainer) close(ctx context.Context) (err error) {
return clients.storage.CloseUpstreams() return clients.storage.Shutdown(ctx)
} }

View file

@ -3,34 +3,14 @@ package home
import ( import (
"net" "net"
"net/netip" "net/netip"
"runtime"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type testDHCP struct {
OnLeases func() (leases []*dhcpsvc.Lease)
OnHostBy func(ip netip.Addr) (host string)
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
}
// Lease implements the [DHCP] interface for testDHCP.
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
// HostByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
// MACByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
// newClientsContainer is a helper that creates a new clients container for // newClientsContainer is a helper that creates a new clients container for
// tests. // tests.
func newClientsContainer(t *testing.T) (c *clientsContainer) { func newClientsContainer(t *testing.T) (c *clientsContainer) {
@ -40,316 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
testing: true, testing: true,
} }
dhcp := &testDHCP{ require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{}))
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
OnHostBy: func(ip netip.Addr) (host string) { return "" },
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
}
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
return c return c
} }
func TestClients(t *testing.T) {
clients := newClientsContainer(t)
t.Run("add_success", func(t *testing.T) {
var (
cliNone = "1.2.3.4"
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
cliIPv6 = netip.MustParseAddr("1:2:3::4")
)
c := &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli1IP, cliIPv6},
}
err := clients.storage.Add(c)
require.NoError(t, err)
c = &client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli2IP},
}
err = clients.storage.Add(c)
require.NoError(t, err)
c, ok := clients.find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.find("1:2:3::4")
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.find(cli2)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
_, ok = clients.find(cliNone)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
})
t.Run("add_fail_name", func(t *testing.T) {
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
})
require.Error(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
err := clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
})
require.Error(t, err)
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
})
assert.Error(t, err)
})
t.Run("update_success", func(t *testing.T) {
var (
cliOld = "1.1.1.1"
cliNew = "1.1.1.2"
cliNewIP = netip.MustParseAddr(cliNew)
)
prev, ok := clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
})
require.NoError(t, err)
_, ok = clients.find(cliOld)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err = clients.storage.Update("client1", &client.Persistent{
Name: "client1-renamed",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true,
})
require.NoError(t, err)
c, ok := clients.find(cliNew)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.storage.FindByName("client1")
require.False(t, ok)
assert.Nil(t, nilCli)
require.Len(t, c.IDs(), 1)
assert.Equal(t, cliNewIP, c.IPs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.storage.RemoveByName("client1-renamed")
require.True(t, ok)
_, ok = clients.find("1.1.1.2")
assert.False(t, ok)
})
t.Run("del_fail", func(t *testing.T) {
ok := clients.storage.RemoveByName("client3")
assert.False(t, ok)
})
t.Run("addhost_success", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host2", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.addHost(ip, "from_arp", client.SourceARP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
})
t.Run("addhost_priority", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host1", client.SourceRDNS)
assert.True(t, ok)
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
})
}
func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer(t)
whois := &whois.Info{
Country: "AU",
Orgname: "Example Org",
}
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", client.SourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
})
require.NoError(t, err)
clients.setWHOISInfo(ip, whois)
rc := clients.storage.ClientRuntime(ip)
require.Nil(t, rc)
assert.True(t, clients.storage.RemoveByName("client1"))
})
}
func TestClientsAddExisting(t *testing.T) {
clients := newClientsContainer(t)
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
})
require.NoError(t, err)
// Now add an auto-client with the same IP.
ok := clients.addHost(ip, "test", client.SourceRDNS)
assert.True(t, ok)
})
t.Run("complicated", func(t *testing.T) {
// TODO(a.garipov): Properly decouple the DHCP server from the client
// storage.
if runtime.GOOS == "windows" {
t.Skip("skipping dhcp test on windows")
}
ip := netip.MustParseAddr("1.2.3.4")
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
Enabled: true,
DataDir: t.TempDir(),
Conf4: dhcpd.V4ServerConf{
Enabled: true,
GatewayIP: netip.MustParseAddr("1.2.3.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("1.2.3.2"),
RangeEnd: netip.MustParseAddr("1.2.3.10"),
},
}
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
clients.dhcp = dhcpServer
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: ip,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
err = clients.storage.Add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
// Add a new client with the IP from the first client's IP range.
err = clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
})
}
func TestClientsCustomUpstream(t *testing.T) { func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t) clients := newClientsContainer(t)

View file

@ -103,6 +103,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true return true
}) })
clients.storage.UpdateDHCP()
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info() src, host := rc.Info()
cj := runtimeClientJSON{ cj := runtimeClientJSON{
@ -117,20 +119,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true return true
}) })
if config.Clients.Sources.DHCP { data.Tags = clients.storage.AllowedTags()
for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{
Name: l.Hostname,
Source: client.SourceDHCP,
IP: l.IP,
WHOIS: &whois.Info{},
}
data.RuntimeClients = append(data.RuntimeClients, cj)
}
}
data.Tags = clientTags
aghhttp.WriteJSONResponseOK(w, r, data) aghhttp.WriteJSONResponseOK(w, r, data)
} }
@ -432,7 +421,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
} }
ip, _ := netip.ParseAddr(idStr) ip, _ := netip.ParseAddr(idStr)
c, ok := clients.find(idStr) c, ok := clients.storage.Find(idStr)
var cj *clientJSON var cj *clientJSON
if !ok { if !ok {
cj = clients.findRuntime(ip, idStr) cj = clients.findRuntime(ip, idStr)
@ -454,7 +443,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil. // non-nil.
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc := clients.findRuntimeClient(ip) rc := clients.storage.ClientRuntime(ip)
if rc == nil { if rc == nil {
// It is still possible that the IP used to be in the runtime clients // It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's // list, but then the server was reloaded. So, check the DNS server's

View file

@ -1,27 +0,0 @@
package home
var clientTags = []string{
"device_audio",
"device_camera",
"device_gameconsole",
"device_laptop",
"device_nas", // Network-attached Storage
"device_other",
"device_pc",
"device_phone",
"device_printer",
"device_securityalarm",
"device_tablet",
"device_tv",
"os_android",
"os_ios",
"os_linux",
"os_macos",
"os_other",
"os_windows",
"user_admin",
"user_child",
"user_regular",
}

View file

@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
@ -414,9 +415,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ClientIP = clientIP setts.ClientIP = clientIP
c, ok := Context.clients.find(clientID) c, ok := Context.clients.storage.Find(clientID)
if !ok { if !ok {
c, ok = Context.clients.find(clientIP.String()) c, ok = Context.clients.storage.Find(clientIP.String())
if !ok { if !ok {
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
@ -459,11 +460,15 @@ func startDNSServer() error {
Context.filters.EnableFilters(false) Context.filters.EnableFilters(false)
Context.clients.Start() // TODO(s.chzhen): Pass context.
err := Context.clients.Start(context.TODO())
err := Context.dnsServer.Start()
if err != nil { if err != nil {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err) return fmt.Errorf("starting clients container: %w", err)
}
err = Context.dnsServer.Start()
if err != nil {
return fmt.Errorf("starting dns server: %w", err)
} }
Context.filters.Start() Context.filters.Start()
@ -500,7 +505,7 @@ func stopDNSServer() (err error) {
return fmt.Errorf("stopping forwarding dns server: %w", err) return fmt.Errorf("stopping forwarding dns server: %w", err)
} }
err = Context.clients.close() err = Context.clients.close(context.TODO())
if err != nil { if err != nil {
return fmt.Errorf("closing clients container: %w", err) return fmt.Errorf("closing clients container: %w", err)
} }

View file

@ -18,9 +18,8 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
tb.Helper() tb.Helper()
s = client.NewStorage(&client.Config{ s, err := client.NewStorage(&client.StorageConfig{})
AllowedTags: nil, require.NoError(tb, err)
})
for _, p := range clients { for _, p := range clients {
p.UID = client.MustNewUID() p.UID = client.MustNewUID()

View file

@ -119,7 +119,7 @@ func Main(clientBuildFS fs.FS) {
log.Info("Received signal %q", sig) log.Info("Received signal %q", sig)
switch sig { switch sig {
case syscall.SIGHUP: case syscall.SIGHUP:
Context.clients.reloadARP() Context.clients.storage.ReloadARP()
Context.tls.reload() Context.tls.reload()
default: default:
cleanup(context.Background()) cleanup(context.Background())