diff --git a/internal/client/client.go b/internal/client/client.go index 9e76f01e..24e8c9a2 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -119,8 +119,8 @@ func (r *Runtime) Info() (cs Source, host string) { return cs, info[0] } -// SetInfo sets a host as a client information from the cs. -func (r *Runtime) SetInfo(cs Source, hosts []string) { +// setInfo sets a host as a client information from the cs. +func (r *Runtime) setInfo(cs Source, hosts []string) { // TODO(s.chzhen): Use contract where hosts must contain non-empty host. if len(hosts) == 1 && hosts[0] == "" { hosts = []string{} @@ -138,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) { } } -// WHOIS returns a WHOIS client information. +// WHOIS returns a copy of WHOIS client information. func (r *Runtime) WHOIS() (info *whois.Info) { - return r.whois + return r.whois.Clone() } -// SetWHOIS sets a WHOIS client information. info must be non-nil. -func (r *Runtime) SetWHOIS(info *whois.Info) { +// setWHOIS sets a WHOIS client information. info must be non-nil. +func (r *Runtime) setWHOIS(info *whois.Info) { r.whois = info } @@ -178,8 +178,8 @@ func (r *Runtime) Addr() (ip netip.Addr) { return r.ip } -// Clone returns a deep copy of the runtime client. -func (r *Runtime) Clone() (c *Runtime) { +// clone returns a deep copy of the runtime client. +func (r *Runtime) clone() (c *Runtime) { return &Runtime{ ip: r.ip, whois: r.whois.Clone(), diff --git a/internal/client/persistent.go b/internal/client/persistent.go index ce68986f..4e09c5b8 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -13,7 +13,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -136,7 +135,8 @@ type Persistent struct { } // validate returns an error if persistent client information contains errors. -func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { +// allTags must be sorted. +func (c *Persistent) validate(allTags []string) (err error) { switch { case c.Name == "": return errors.Error("empty name") @@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { } for _, t := range c.Tags { - if !allTags.Has(t) { + _, ok := slices.BinarySearch(allTags, t) + if !ok { return fmt.Errorf("invalid tag: %q", t) } } diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index a96c3778..c62b76da 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -1,11 +1,8 @@ package client import ( - "net/netip" "testing" - "github.com/AdguardTeam/golibs/container" - "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -125,69 +122,3 @@ func TestPersistent_EqualIDs(t *testing.T) { }) } } - -func TestPersistent_Validate(t *testing.T) { - const ( - allowedTag = "allowed_tag" - notAllowedTag = "not_allowed_tag" - ) - - allowedTags := container.NewMapSet(allowedTag) - - testCases := []struct { - name string - cli *Persistent - wantErrMsg string - }{{ - name: "success", - cli: &Persistent{ - Name: "basic", - IPs: []netip.Addr{ - netip.MustParseAddr("1.2.3.4"), - }, - UID: MustNewUID(), - }, - wantErrMsg: "", - }, { - name: "empty_name", - cli: &Persistent{ - Name: "", - }, - wantErrMsg: "empty name", - }, { - name: "no_id", - cli: &Persistent{ - Name: "no_id", - }, - wantErrMsg: "id required", - }, { - name: "no_uid", - cli: &Persistent{ - Name: "no_uid", - IPs: []netip.Addr{ - netip.MustParseAddr("1.2.3.4"), - }, - }, - wantErrMsg: "uid required", - }, { - name: "not_allowed_tag", - cli: &Persistent{ - Name: "basic", - IPs: []netip.Addr{ - netip.MustParseAddr("1.2.3.4"), - }, - UID: MustNewUID(), - Tags: []string{ - notAllowedTag, - }, - }, - wantErrMsg: `invalid tag: "` + notAllowedTag + `"`, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := tc.cli.validate(allowedTags) - testutil.AssertErrorMsg(t, tc.wantErrMsg, err) - }) - } -} diff --git a/internal/client/runtimeindex.go b/internal/client/runtimeindex.go index 300fdca0..fc994bfa 100644 --- a/internal/client/runtimeindex.go +++ b/internal/client/runtimeindex.go @@ -2,39 +2,34 @@ package client import "net/netip" -// RuntimeIndex stores information about runtime clients. -type RuntimeIndex struct { +// runtimeIndex stores information about runtime clients. +type runtimeIndex struct { // index maps IP address to runtime client. index map[netip.Addr]*Runtime } -// NewRuntimeIndex returns initialized runtime index. -func NewRuntimeIndex() (ri *RuntimeIndex) { - return &RuntimeIndex{ +// newRuntimeIndex returns initialized runtime index. +func newRuntimeIndex() (ri *runtimeIndex) { + return &runtimeIndex{ index: map[netip.Addr]*Runtime{}, } } -// Client returns the saved runtime client by ip. If no such client exists, +// client returns the saved runtime client by ip. If no such client exists, // returns nil. -func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) { +func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) { return ri.index[ip] } -// Add saves the runtime client in the index. IP address of a client must be +// add saves the runtime client in the index. IP address of a client must be // unique. See [Runtime.Client]. rc must not be nil. -func (ri *RuntimeIndex) Add(rc *Runtime) { +func (ri *runtimeIndex) add(rc *Runtime) { ip := rc.Addr() ri.index[ip] = rc } -// Size returns the number of the runtime clients. -func (ri *RuntimeIndex) Size() (n int) { - return len(ri.index) -} - -// Range calls f for each runtime client in an undefined order. -func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { +// rangeClients calls f for each runtime client in an undefined order. +func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) { for _, rc := range ri.index { if !f(rc) { return @@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { } } -// Delete removes the runtime client by ip. -func (ri *RuntimeIndex) Delete(ip netip.Addr) { - delete(ri.index, ip) +// setInfo sets the client information from cs for runtime client stored by ip. +// If no such client exists, it creates one. +func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) { + rc = ri.index[ip] + if rc == nil { + rc = NewRuntime(ip) + ri.add(rc) + } + + rc.setInfo(cs, hosts) + + return rc } -// DeleteBySource removes all runtime clients that have information only from -// the specified source and returns the number of removed clients. -func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) { - for ip, rc := range ri.index { +// clearSource removes information from the specified source from all clients. +func (ri *runtimeIndex) clearSource(src Source) { + for _, rc := range ri.index { rc.unset(src) + } +} +// removeEmpty removes empty runtime clients and returns the number of removed +// clients. +func (ri *runtimeIndex) removeEmpty() (n int) { + for ip, rc := range ri.index { if rc.isEmpty() { delete(ri.index, ip) n++ diff --git a/internal/client/runtimeindex_test.go b/internal/client/runtimeindex_test.go deleted file mode 100644 index 66b975a0..00000000 --- a/internal/client/runtimeindex_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package client_test - -import ( - "net/netip" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/stretchr/testify/assert" -) - -func TestRuntimeIndex(t *testing.T) { - const cliSrc = client.SourceARP - - var ( - ip1 = netip.MustParseAddr("1.1.1.1") - ip2 = netip.MustParseAddr("2.2.2.2") - ip3 = netip.MustParseAddr("3.3.3.3") - ) - - ri := client.NewRuntimeIndex() - currentSize := 0 - - testCases := []struct { - ip netip.Addr - name string - hosts []string - src client.Source - }{{ - src: cliSrc, - ip: ip1, - name: "1", - hosts: []string{"host1"}, - }, { - src: cliSrc, - ip: ip2, - name: "2", - hosts: []string{"host2"}, - }, { - src: cliSrc, - ip: ip3, - name: "3", - hosts: []string{"host3"}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - rc := client.NewRuntime(tc.ip) - rc.SetInfo(tc.src, tc.hosts) - - ri.Add(rc) - currentSize++ - - got := ri.Client(tc.ip) - assert.Equal(t, rc, got) - }) - } - - t.Run("size", func(t *testing.T) { - assert.Equal(t, currentSize, ri.Size()) - }) - - t.Run("range", func(t *testing.T) { - s := 0 - - ri.Range(func(rc *client.Runtime) (cont bool) { - s++ - - return true - }) - - assert.Equal(t, currentSize, s) - }) - - t.Run("delete", func(t *testing.T) { - ri.Delete(ip1) - currentSize-- - - assert.Equal(t, currentSize, ri.Size()) - }) - - t.Run("delete_by_src", func(t *testing.T) { - assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc)) - assert.Equal(t, 0, ri.Size()) - }) -} diff --git a/internal/client/storage.go b/internal/client/storage.go index 23bb6ca8..da6dda5c 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -1,30 +1,113 @@ package client import ( + "context" "fmt" "net" "net/netip" "slices" "sync" + "time" - "github.com/AdguardTeam/golibs/container" + "github.com/AdguardTeam/AdGuardHome/internal/arpdb" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" ) -// Config is the client storage configuration structure. -// -// TODO(s.chzhen): Expand. -type Config struct { - // AllowedTags is a list of all allowed client tags. - AllowedTags []string +// allowedTags is the list of available client tags. +var allowedTags = []string{ + "device_audio", + "device_camera", + "device_gameconsole", + "device_laptop", + "device_nas", // Network-attached Storage + "device_other", + "device_pc", + "device_phone", + "device_printer", + "device_securityalarm", + "device_tablet", + "device_tv", + + "os_android", + "os_ios", + "os_linux", + "os_macos", + "os_other", + "os_windows", + + "user_admin", + "user_child", + "user_regular", +} + +// DHCP is an interface for accessing DHCP lease data the [Storage] needs. +type DHCP interface { + // Leases returns all the DHCP leases. + Leases() (leases []*dhcpsvc.Lease) + + // HostByIP returns the hostname of the DHCP client with the given IP + // address. host will be empty if there is no such client, due to an + // assumption that a DHCP client must always have a hostname. + HostByIP(ip netip.Addr) (host string) + + // MACByIP returns the MAC address for the given IP address leased. It + // returns nil if there is no such client, due to an assumption that a DHCP + // client must always have a MAC address. + MACByIP(ip netip.Addr) (mac net.HardwareAddr) +} + +// EmptyDHCP is the empty [DHCP] implementation that does nothing. +type EmptyDHCP struct{} + +// type check +var _ DHCP = EmptyDHCP{} + +// Leases implements the [DHCP] interface for emptyDHCP. +func (EmptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil } + +// HostByIP implements the [DHCP] interface for emptyDHCP. +func (EmptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" } + +// MACByIP implements the [DHCP] interface for emptyDHCP. +func (EmptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil } + +// HostsContainer is an interface for receiving updates to the system hosts +// file. +type HostsContainer interface { + Upd() (updates <-chan *hostsfile.DefaultStorage) +} + +// StorageConfig is the client storage configuration structure. +type StorageConfig struct { + // DHCP is used to match IPs against MACs of persistent clients and update + // [SourceDHCP] runtime client information. It must not be nil. + DHCP DHCP + + // EtcHosts is used to update [SourceHostsFile] runtime client information. + EtcHosts HostsContainer + + // ARPDB is used to update [SourceARP] runtime client information. + ARPDB arpdb.Interface + + // InitialClients is a list of persistent clients parsed from the + // configuration file. Each client must not be nil. + InitialClients []*Persistent + + // ARPClientsUpdatePeriod defines how often [SourceARP] runtime client + // information is updated. + ARPClientsUpdatePeriod time.Duration + + // RuntimeSourceDHCP specifies whether to update [SourceDHCP] information + // of runtime clients. + RuntimeSourceDHCP bool } // Storage contains information about persistent and runtime clients. type Storage struct { - // allowedTags is a set of all allowed tags. - allowedTags *container.MapSet[string] - // mu protects indexes of persistent and runtime clients. mu *sync.Mutex @@ -32,19 +115,250 @@ type Storage struct { index *index // runtimeIndex contains information about runtime clients. - runtimeIndex *RuntimeIndex + runtimeIndex *runtimeIndex + + // dhcp is used to update [SourceDHCP] runtime client information. + dhcp DHCP + + // etcHosts is used to update [SourceHostsFile] runtime client information. + etcHosts HostsContainer + + // arpDB is used to update [SourceARP] runtime client information. + arpDB arpdb.Interface + + // done is the shutdown signaling channel. + done chan struct{} + + // allowedTags is a sorted list of all allowed tags. It must not be + // modified after initialization. + // + // TODO(s.chzhen): Use custom type. + allowedTags []string + + // arpClientsUpdatePeriod defines how often [SourceARP] runtime client + // information is updated. It must be greater than zero. + arpClientsUpdatePeriod time.Duration + + // runtimeSourceDHCP specifies whether to update [SourceDHCP] information + // of runtime clients. + runtimeSourceDHCP bool } // NewStorage returns initialized client storage. conf must not be nil. -func NewStorage(conf *Config) (s *Storage) { - allowedTags := container.NewMapSet(conf.AllowedTags...) +func NewStorage(conf *StorageConfig) (s *Storage, err error) { + tags := slices.Clone(allowedTags) + slices.Sort(tags) - return &Storage{ - allowedTags: allowedTags, - mu: &sync.Mutex{}, - index: newIndex(), - runtimeIndex: NewRuntimeIndex(), + s = &Storage{ + allowedTags: tags, + mu: &sync.Mutex{}, + index: newIndex(), + runtimeIndex: newRuntimeIndex(), + dhcp: conf.DHCP, + etcHosts: conf.EtcHosts, + arpDB: conf.ARPDB, + done: make(chan struct{}), + arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod, + runtimeSourceDHCP: conf.RuntimeSourceDHCP, } + + for i, p := range conf.InitialClients { + err = s.Add(p) + if err != nil { + return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err) + } + } + + s.ReloadARP() + + return s, nil +} + +// Start starts the goroutines for updating the runtime client information. +// +// TODO(s.chzhen): Pass context. +func (s *Storage) Start(_ context.Context) (err error) { + go s.periodicARPUpdate() + go s.handleHostsUpdates() + + return nil +} + +// Shutdown gracefully stops the client storage. +// +// TODO(s.chzhen): Pass context. +func (s *Storage) Shutdown(_ context.Context) (err error) { + close(s.done) + + return s.closeUpstreams() +} + +// periodicARPUpdate periodically reloads runtime clients from ARP. It is +// intended to be used as a goroutine. +func (s *Storage) periodicARPUpdate() { + defer log.OnPanic("storage") + + t := time.NewTicker(s.arpClientsUpdatePeriod) + + for { + select { + case <-t.C: + s.ReloadARP() + case <-s.done: + return + } + } +} + +// ReloadARP reloads runtime clients from ARP, if configured. +func (s *Storage) ReloadARP() { + if s.arpDB != nil { + s.addFromSystemARP() + } +} + +// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a +// command. +func (s *Storage) addFromSystemARP() { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.arpDB.Refresh(); err != nil { + s.arpDB = arpdb.Empty{} + log.Error("refreshing arp container: %s", err) + + return + } + + ns := s.arpDB.Neighbors() + if len(ns) == 0 { + log.Debug("refreshing arp container: the update is empty") + + return + } + + src := SourceARP + s.runtimeIndex.clearSource(src) + + for _, n := range ns { + s.runtimeIndex.setInfo(n.IP, src, []string{n.Name}) + } + + removed := s.runtimeIndex.removeEmpty() + + log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed) +} + +// handleHostsUpdates receives the updates from the hosts container and adds +// them to the clients storage. It is intended to be used as a goroutine. +func (s *Storage) handleHostsUpdates() { + if s.etcHosts == nil { + return + } + + defer log.OnPanic("storage") + + for { + select { + case upd, ok := <-s.etcHosts.Upd(): + if !ok { + return + } + + s.addFromHostsFile(upd) + case <-s.done: + return + } + } +} + +// addFromHostsFile fills the client-hostname pairing index from the system's +// hosts files. +func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) { + s.mu.Lock() + defer s.mu.Unlock() + + src := SourceHostsFile + s.runtimeIndex.clearSource(src) + + added := 0 + hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { + // Only the first name of the first record is considered a canonical + // hostname for the IP address. + // + // TODO(e.burkov): Consider using all the names from all the records. + s.runtimeIndex.setInfo(addr, src, []string{names[0]}) + added++ + + return true + }) + + removed := s.runtimeIndex.removeEmpty() + log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed) +} + +// type check +var _ AddressUpdater = (*Storage)(nil) + +// UpdateAddress implements the [AddressUpdater] interface for *Storage +func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { + // Common fast path optimization. + if host == "" && info == nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + if host != "" { + s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host}) + } + + if info != nil { + s.setWHOISInfo(ip, info) + } +} + +// UpdateDHCP updates [SourceDHCP] runtime client information. +func (s *Storage) UpdateDHCP() { + if s.dhcp == nil || !s.runtimeSourceDHCP { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + src := SourceDHCP + s.runtimeIndex.clearSource(src) + + added := 0 + for _, l := range s.dhcp.Leases() { + s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname}) + added++ + } + + removed := s.runtimeIndex.removeEmpty() + log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed) +} + +// setWHOISInfo sets the WHOIS information for a runtime client. +func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) { + _, ok := s.index.findByIP(ip) + if ok { + log.Debug("storage: client for %s is already created, ignore whois info", ip) + + return + } + + rc := s.runtimeIndex.client(ip) + if rc == nil { + rc = NewRuntime(ip) + s.runtimeIndex.add(rc) + } + + rc.setWHOIS(wi) + + log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi) } // Add stores persistent client information or returns an error. @@ -94,6 +408,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) { // Find finds persistent client by string representation of the client ID, IP // address, or MAC. And returns its shallow copy. +// +// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain +// the parsed IP address, if any. func (s *Storage) Find(id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() @@ -103,6 +420,16 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) { return p.ShallowClone(), ok } + ip, err := netip.ParseAddr(id) + if err != nil { + return nil, false + } + + foundMAC := s.dhcp.MACByIP(ip) + if foundMAC != nil { + return s.FindByMAC(foundMAC) + } + return nil, false } @@ -130,11 +457,9 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { return nil, false } -// FindByMAC finds persistent client by MAC and returns its shallow copy. +// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu +// is expected to be locked. func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) { - s.mu.Lock() - defer s.mu.Unlock() - p, ok = s.index.findByMAC(mac) if ok { return p.ShallowClone(), ok @@ -216,8 +541,8 @@ func (s *Storage) Size() (n int) { return s.index.size() } -// CloseUpstreams closes upstream configurations of persistent clients. -func (s *Storage) CloseUpstreams() (err error) { +// closeUpstreams closes upstream configurations of persistent clients. +func (s *Storage) closeUpstreams() (err error) { s.mu.Lock() defer s.mu.Unlock() @@ -226,89 +551,27 @@ func (s *Storage) CloseUpstreams() (err error) { // ClientRuntime returns a copy of the saved runtime client by ip. If no such // client exists, returns nil. -// -// TODO(s.chzhen): Use it. func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() - return s.runtimeIndex.Client(ip) -} - -// UpdateRuntime updates the stored runtime client with information from rc. If -// no such client exists, saves the copy of rc in storage. rc must not be nil. -func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.updateRuntimeLocked(rc) -} - -// updateRuntimeLocked updates the stored runtime client with information from -// rc. rc must not be nil. Storage.mu is expected to be locked. -func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) { - stored := s.runtimeIndex.Client(rc.ip) - if stored == nil { - s.runtimeIndex.Add(rc.Clone()) - - return true + rc = s.runtimeIndex.client(ip) + if rc != nil { + return rc.clone() } - if rc.whois != nil { - stored.whois = rc.whois.Clone() + if !s.runtimeSourceDHCP { + return nil } - if rc.arp != nil { - stored.arp = slices.Clone(rc.arp) + host := s.dhcp.HostByIP(ip) + if host == "" { + return nil } - if rc.rdns != nil { - stored.rdns = slices.Clone(rc.rdns) - } + rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host}) - if rc.dhcp != nil { - stored.dhcp = slices.Clone(rc.dhcp) - } - - if rc.hostsFile != nil { - stored.hostsFile = slices.Clone(rc.hostsFile) - } - - return false -} - -// BatchUpdateBySource updates the stored runtime clients information from the -// specified source and returns the number of added and removed clients. -func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, rc := range s.runtimeIndex.index { - rc.unset(src) - } - - for _, rc := range rcs { - if s.updateRuntimeLocked(rc) { - added++ - } - } - - for ip, rc := range s.runtimeIndex.index { - if rc.isEmpty() { - delete(s.runtimeIndex.index, ip) - removed++ - } - } - - return added, removed -} - -// SizeRuntime returns the number of the runtime clients. -func (s *Storage) SizeRuntime() (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.runtimeIndex.Size() + return rc.clone() } // RangeRuntime calls f for each runtime client in an undefined order. @@ -316,16 +579,11 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() - s.runtimeIndex.Range(f) + s.runtimeIndex.rangeClients(f) } -// DeleteBySource removes all runtime clients that have information only from -// the specified source and returns the number of removed clients. -// -// TODO(s.chzhen): Use it. -func (s *Storage) DeleteBySource(src Source) (n int) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.runtimeIndex.DeleteBySource(src) +// AllowedTags returns the list of available client tags. tags must not be +// modified. +func (s *Storage) AllowedTags() (tags []string) { + return s.allowedTags } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 5ac02747..60d766d0 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -3,23 +3,513 @@ package client_test import ( "net" "net/netip" + "runtime" + "slices" + "sync" "testing" + "time" + "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testHostsContainer is a mock implementation of the [client.HostsContainer] +// interface. +type testHostsContainer struct { + onUpd func() (updates <-chan *hostsfile.DefaultStorage) +} + +// type check +var _ client.HostsContainer = (*testHostsContainer)(nil) + +// Upd implements the [client.HostsContainer] interface for *testHostsContainer. +func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { + return c.onUpd() +} + +// Interface stores and refreshes the network neighborhood reported by ARP +// (Address Resolution Protocol). +type Interface interface { + // Refresh updates the stored data. It must be safe for concurrent use. + Refresh() (err error) + + // Neighbors returnes the last set of data reported by ARP. Both the method + // and it's result must be safe for concurrent use. + Neighbors() (ns []arpdb.Neighbor) +} + +// testARPDB is a mock implementation of the [arpdb.Interface]. +type testARPDB struct { + onRefresh func() (err error) + onNeighbors func() (ns []arpdb.Neighbor) +} + +// type check +var _ arpdb.Interface = (*testARPDB)(nil) + +// Refresh implements the [arpdb.Interface] interface for *testARP. +func (c *testARPDB) Refresh() (err error) { + return c.onRefresh() +} + +// Neighbors implements the [arpdb.Interface] interface for *testARP. +func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) { + return c.onNeighbors() +} + +// testDHCP is a mock implementation of the [client.DHCP]. +type testDHCP struct { + OnLeases func() (leases []*dhcpsvc.Lease) + OnHostBy func(ip netip.Addr) (host string) + OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) +} + +// type check +var _ client.DHCP = (*testDHCP)(nil) + +// Lease implements the [client.DHCP] interface for *testDHCP. +func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } + +// HostByIP implements the [client.DHCP] interface for *testDHCP. +func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } + +// MACByIP implements the [client.DHCP] interface for *testDHCP. +func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } + +// compareRuntimeInfo is a helper function that returns true if the runtime +// client has provided info. +func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) { + s, h := rc.Info() + if s != src { + return false + } else if h != host { + return false + } + + return true +} + +func TestStorage_Add_hostsfile(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "client_one" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + ) + + hostCh := make(chan *hostsfile.DefaultStorage) + h := &testHostsContainer{ + onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh }, + } + + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: client.EmptyDHCP{}, + EtcHosts: h, + ARPClientsUpdatePeriod: testTimeout / 10, + }) + require.NoError(t, err) + + err = storage.Start(testutil.ContextWithTimeout(t, testTimeout)) + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) + }) + + t.Run("add_hosts", func(t *testing.T) { + var s *hostsfile.DefaultStorage + s, err = hostsfile.NewDefaultStorage() + require.NoError(t, err) + + s.Add(&hostsfile.Record{ + Addr: cliIP1, + Names: []string{cliName1}, + }) + + testutil.RequireSend(t, hostCh, s, testTimeout) + + require.Eventually(t, func() (ok bool) { + cli1 := storage.ClientRuntime(cliIP1) + if cli1 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1)) + + return true + }, testTimeout, testTimeout/10) + }) + + t.Run("update_hosts", func(t *testing.T) { + var s *hostsfile.DefaultStorage + s, err = hostsfile.NewDefaultStorage() + require.NoError(t, err) + + s.Add(&hostsfile.Record{ + Addr: cliIP2, + Names: []string{cliName2}, + }) + + testutil.RequireSend(t, hostCh, s, testTimeout) + + require.Eventually(t, func() (ok bool) { + cli2 := storage.ClientRuntime(cliIP2) + if cli2 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2)) + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + return true + }, testTimeout, testTimeout/10) + }) +} + +func TestStorage_Add_arp(t *testing.T) { + var ( + mu sync.Mutex + neighbors []arpdb.Neighbor + + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "client_one" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + ) + + a := &testARPDB{ + onRefresh: func() (err error) { return nil }, + onNeighbors: func() (ns []arpdb.Neighbor) { + mu.Lock() + defer mu.Unlock() + + return neighbors + }, + } + + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: client.EmptyDHCP{}, + ARPDB: a, + ARPClientsUpdatePeriod: testTimeout / 10, + }) + require.NoError(t, err) + + err = storage.Start(testutil.ContextWithTimeout(t, testTimeout)) + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) + }) + + t.Run("add_hosts", func(t *testing.T) { + func() { + mu.Lock() + defer mu.Unlock() + + neighbors = []arpdb.Neighbor{{ + Name: cliName1, + IP: cliIP1, + }} + }() + + require.Eventually(t, func() (ok bool) { + cli1 := storage.ClientRuntime(cliIP1) + if cli1 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1)) + + return true + }, testTimeout, testTimeout/10) + }) + + t.Run("update_hosts", func(t *testing.T) { + func() { + mu.Lock() + defer mu.Unlock() + + neighbors = []arpdb.Neighbor{{ + Name: cliName2, + IP: cliIP2, + }} + }() + + require.Eventually(t, func() (ok bool) { + cli2 := storage.ClientRuntime(cliIP2) + if cli2 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2)) + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + return true + }, testTimeout, testTimeout/10) + }) +} + +func TestStorage_Add_whois(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + + cliIP3 = netip.MustParseAddr("3.3.3.3") + cliName3 = "client_three" + ) + + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: client.EmptyDHCP{}, + }) + require.NoError(t, err) + + whois := &whois.Info{ + Country: "AU", + Orgname: "Example Org", + } + + t.Run("new_client", func(t *testing.T) { + storage.UpdateAddress(cliIP1, "", whois) + cli1 := storage.ClientRuntime(cliIP1) + require.NotNil(t, cli1) + + assert.Equal(t, whois, cli1.WHOIS()) + }) + + t.Run("existing_runtime_client", func(t *testing.T) { + storage.UpdateAddress(cliIP2, cliName2, nil) + storage.UpdateAddress(cliIP2, "", whois) + + cli2 := storage.ClientRuntime(cliIP2) + require.NotNil(t, cli2) + + assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2)) + + assert.Equal(t, whois, cli2.WHOIS()) + }) + + t.Run("can't_set_persistent_client", func(t *testing.T) { + err = storage.Add(&client.Persistent{ + Name: cliName3, + UID: client.MustNewUID(), + IPs: []netip.Addr{cliIP3}, + }) + require.NoError(t, err) + + storage.UpdateAddress(cliIP3, "", whois) + rc := storage.ClientRuntime(cliIP3) + require.Nil(t, rc) + }) +} + +func TestClientsDHCP(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "one.dhcp" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliMAC2 = mustParseMAC("22:22:22:22:22:22") + cliName2 = "two.dhcp" + + cliIP3 = netip.MustParseAddr("3.3.3.3") + cliMAC3 = mustParseMAC("33:33:33:33:33:33") + cliName3 = "three.dhcp" + + prsCliIP = netip.MustParseAddr("4.3.2.1") + prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") + prsCliName = "persistent.dhcp" + ) + + ipToHost := map[netip.Addr]string{ + cliIP1: cliName1, + } + ipToMAC := map[netip.Addr]net.HardwareAddr{ + prsCliIP: prsCliMAC, + } + + leases := []*dhcpsvc.Lease{{ + IP: cliIP2, + Hostname: cliName2, + HWAddr: cliMAC2, + }, { + IP: cliIP3, + Hostname: cliName3, + HWAddr: cliMAC3, + }} + + d := &testDHCP{ + OnLeases: func() (ls []*dhcpsvc.Lease) { + return leases + }, + OnHostBy: func(ip netip.Addr) (host string) { + return ipToHost[ip] + }, + OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { + return ipToMAC[ip] + }, + } + + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: d, + RuntimeSourceDHCP: true, + }) + require.NoError(t, err) + + t.Run("find_runtime", func(t *testing.T) { + cli1 := storage.ClientRuntime(cliIP1) + require.NotNil(t, cli1) + + assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1)) + }) + + t.Run("find_persistent", func(t *testing.T) { + err = storage.Add(&client.Persistent{ + Name: prsCliName, + UID: client.MustNewUID(), + MACs: []net.HardwareAddr{prsCliMAC}, + }) + require.NoError(t, err) + + prsCli, ok := storage.Find(prsCliIP.String()) + require.True(t, ok) + + assert.Equal(t, prsCliName, prsCli.Name) + }) + + t.Run("leases", func(t *testing.T) { + delete(ipToHost, cliIP1) + storage.UpdateDHCP() + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + for i, l := range leases { + cli := storage.ClientRuntime(l.IP) + require.NotNil(t, cli) + + src, host := cli.Info() + assert.Equal(t, client.SourceDHCP, src) + assert.Equal(t, leases[i].Hostname, host) + } + }) + + t.Run("range", func(t *testing.T) { + s := 0 + storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { + s++ + + return true + }) + + assert.Equal(t, len(leases), s) + }) +} + +func TestClientsAddExisting(t *testing.T) { + t.Run("simple", func(t *testing.T) { + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: client.EmptyDHCP{}, + }) + require.NoError(t, err) + + ip := netip.MustParseAddr("1.1.1.1") + + // Add a client. + err = storage.Add(&client.Persistent{ + Name: "client1", + UID: client.MustNewUID(), + IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, + Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, + MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, + }) + require.NoError(t, err) + + // Now add an auto-client with the same IP. + storage.UpdateAddress(ip, "test", nil) + rc := storage.ClientRuntime(ip) + assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test")) + }) + + t.Run("complicated", func(t *testing.T) { + // TODO(a.garipov): Properly decouple the DHCP server from the client + // storage. + if runtime.GOOS == "windows" { + t.Skip("skipping dhcp test on windows") + } + + // First, init a DHCP server with a single static lease. + config := &dhcpd.ServerConfig{ + Enabled: true, + DataDir: t.TempDir(), + Conf4: dhcpd.V4ServerConf{ + Enabled: true, + GatewayIP: netip.MustParseAddr("1.2.3.1"), + SubnetMask: netip.MustParseAddr("255.255.255.0"), + RangeStart: netip.MustParseAddr("1.2.3.2"), + RangeEnd: netip.MustParseAddr("1.2.3.10"), + }, + } + + dhcpServer, err := dhcpd.Create(config) + require.NoError(t, err) + + storage, err := client.NewStorage(&client.StorageConfig{ + DHCP: dhcpServer, + }) + require.NoError(t, err) + + ip := netip.MustParseAddr("1.2.3.4") + + err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + IP: ip, + Hostname: "testhost", + Expiry: time.Now().Add(time.Hour), + }) + require.NoError(t, err) + + // Add a new client with the same IP as for a client with MAC. + err = storage.Add(&client.Persistent{ + Name: "client2", + UID: client.MustNewUID(), + IPs: []netip.Addr{ip}, + }) + require.NoError(t, err) + + // Add a new client with the IP from the first client's IP range. + err = storage.Add(&client.Persistent{ + Name: "client3", + UID: client.MustNewUID(), + IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + }) + require.NoError(t, err) + }) +} + // newStorage is a helper function that returns a client storage filled with // persistent clients from the m. It also generates a UID for each client. func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage(&client.Config{ - AllowedTags: nil, + s, err := client.NewStorage(&client.StorageConfig{ + DHCP: client.EmptyDHCP{}, }) + require.NoError(tb, err) for _, c := range m { c.UID = client.MustNewUID() @@ -31,14 +521,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { return s } -// newRuntimeClient is a helper function that returns a new runtime client. -func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) { - rc = client.NewRuntime(ip) - rc.SetInfo(source, []string{host}) - - return rc -} - // mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an // error. func mustParseMAC(s string) (mac net.HardwareAddr) { @@ -55,7 +537,7 @@ func TestStorage_Add(t *testing.T) { existingName = "existing_name" existingClientID = "existing_client_id" - allowedTag = "tag" + allowedTag = "user_admin" notAllowedTag = "not_allowed_tag" ) @@ -73,10 +555,20 @@ func TestStorage_Add(t *testing.T) { UID: existingClientUID, } - s := client.NewStorage(&client.Config{ - AllowedTags: []string{allowedTag}, - }) - err := s.Add(existingClient) + s, err := client.NewStorage(&client.StorageConfig{}) + require.NoError(t, err) + + tags := s.AllowedTags() + require.NotZero(t, len(tags)) + require.True(t, slices.IsSorted(tags)) + + _, ok := slices.BinarySearch(tags, allowedTag) + require.True(t, ok) + + _, ok = slices.BinarySearch(tags, notAllowedTag) + require.False(t, ok) + + err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { @@ -136,12 +628,43 @@ func TestStorage_Add(t *testing.T) { }, { name: "not_allowed_tag", cli: &client.Persistent{ - Name: "nont_allowed_tag", + Name: "not_allowed_tag", Tags: []string{notAllowedTag}, IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, + }, { + name: "allowed_tag", + cli: &client.Persistent{ + Name: "allowed_tag", + Tags: []string{allowedTag}, + IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "", + cli: &client.Persistent{ + Name: "", + IPs: []netip.Addr{netip.MustParseAddr("6.6.6.6")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "adding client: empty name", + }, { + name: "no_id", + cli: &client.Persistent{ + Name: "no_id", + UID: client.MustNewUID(), + }, + wantErrMsg: "adding client: id required", + }, { + name: "no_uid", + cli: &client.Persistent{ + Name: "no_uid", + IPs: []netip.Addr{netip.MustParseAddr("7.7.7.7")}, + }, + wantErrMsg: "adding client: uid required", }} for _, tc := range testCases { @@ -164,10 +687,10 @@ func TestStorage_RemoveByName(t *testing.T) { UID: client.MustNewUID(), } - s := client.NewStorage(&client.Config{ - AllowedTags: nil, - }) - err := s.Add(existingClient) + s, err := client.NewStorage(&client.StorageConfig{}) + require.NoError(t, err) + + err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { @@ -191,9 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) { } t.Run("duplicate_remove", func(t *testing.T) { - s = client.NewStorage(&client.Config{ - AllowedTags: nil, - }) + s, err = client.NewStorage(&client.StorageConfig{}) + require.NoError(t, err) + err = s.Add(existingClient) require.NoError(t, err) @@ -623,157 +1146,3 @@ func TestStorage_RangeByName(t *testing.T) { }) } } - -func TestStorage_UpdateRuntime(t *testing.T) { - const ( - addedARP = "added_arp" - addedSecondARP = "added_arp" - - updatedARP = "updated_arp" - - cliCity = "City" - cliCountry = "Country" - cliOrgname = "Orgname" - ) - - var ( - ip = netip.MustParseAddr("1.1.1.1") - ip2 = netip.MustParseAddr("2.2.2.2") - ) - - updated := client.NewRuntime(ip) - updated.SetInfo(client.SourceARP, []string{updatedARP}) - - info := &whois.Info{ - City: cliCity, - Country: cliCountry, - Orgname: cliOrgname, - } - updated.SetWHOIS(info) - - s := client.NewStorage(&client.Config{ - AllowedTags: nil, - }) - - t.Run("add_arp_client", func(t *testing.T) { - added := client.NewRuntime(ip) - added.SetInfo(client.SourceARP, []string{addedARP}) - - require.True(t, s.UpdateRuntime(added)) - require.Equal(t, 1, s.SizeRuntime()) - - got := s.ClientRuntime(ip) - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, addedARP, host) - }) - - t.Run("add_second_arp_client", func(t *testing.T) { - added := client.NewRuntime(ip2) - added.SetInfo(client.SourceARP, []string{addedSecondARP}) - - require.True(t, s.UpdateRuntime(added)) - require.Equal(t, 2, s.SizeRuntime()) - - got := s.ClientRuntime(ip2) - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, addedSecondARP, host) - }) - - t.Run("update_first_client", func(t *testing.T) { - require.False(t, s.UpdateRuntime(updated)) - got := s.ClientRuntime(ip) - require.Equal(t, 2, s.SizeRuntime()) - - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, updatedARP, host) - }) - - t.Run("remove_arp_info", func(t *testing.T) { - n := s.DeleteBySource(client.SourceARP) - require.Equal(t, 1, n) - require.Equal(t, 1, s.SizeRuntime()) - - got := s.ClientRuntime(ip) - source, _ := got.Info() - assert.Equal(t, client.SourceWHOIS, source) - assert.Equal(t, info, got.WHOIS()) - }) - - t.Run("remove_whois_info", func(t *testing.T) { - n := s.DeleteBySource(client.SourceWHOIS) - require.Equal(t, 1, n) - require.Equal(t, 0, s.SizeRuntime()) - }) -} - -func TestStorage_BatchUpdateBySource(t *testing.T) { - const ( - defSrc = client.SourceARP - - cliFirstHost1 = "host1" - cliFirstHost2 = "host2" - cliUpdatedHost3 = "host3" - cliUpdatedHost4 = "host4" - cliUpdatedHost5 = "host5" - ) - - var ( - cliFirstIP1 = netip.MustParseAddr("1.1.1.1") - cliFirstIP2 = netip.MustParseAddr("2.2.2.2") - cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3") - cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4") - cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5") - ) - - firstClients := []*client.Runtime{ - newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1), - newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2), - } - - updatedClients := []*client.Runtime{ - newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3), - newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4), - newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5), - } - - s := client.NewStorage(&client.Config{ - AllowedTags: nil, - }) - - t.Run("populate_storage_with_first_clients", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, firstClients) - require.Equal(t, len(firstClients), added) - require.Equal(t, 0, removed) - require.Equal(t, len(firstClients), s.SizeRuntime()) - - rc := s.ClientRuntime(cliFirstIP1) - src, host := rc.Info() - assert.Equal(t, defSrc, src) - assert.Equal(t, cliFirstHost1, host) - }) - - t.Run("update_storage", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, updatedClients) - require.Equal(t, len(updatedClients), added) - require.Equal(t, len(firstClients), removed) - require.Equal(t, len(updatedClients), s.SizeRuntime()) - - rc := s.ClientRuntime(cliUpdatedIP3) - src, host := rc.Info() - assert.Equal(t, defSrc, src) - assert.Equal(t, cliUpdatedHost3, host) - - rc = s.ClientRuntime(cliFirstIP1) - assert.Nil(t, rc) - }) - - t.Run("remove_all", func(t *testing.T) { - added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{}) - require.Equal(t, 0, added) - require.Equal(t, len(updatedClients), removed) - require.Equal(t, 0, s.SizeRuntime()) - }) -} diff --git a/internal/home/clients.go b/internal/home/clients.go index 341064f4..66a44a62 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -1,8 +1,8 @@ package home import ( + "context" "fmt" - "net" "net/netip" "slices" "sync" @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -20,47 +19,18 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/hostsfile" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" ) -// DHCP is an interface for accessing DHCP lease data the [clientsContainer] -// needs. -type DHCP interface { - // Leases returns all the DHCP leases. - Leases() (leases []*dhcpsvc.Lease) - - // HostByIP returns the hostname of the DHCP client with the given IP - // address. The address will be netip.Addr{} if there is no such client, - // due to an assumption that a DHCP client must always have a hostname. - HostByIP(ip netip.Addr) (host string) - - // MACByIP returns the MAC address for the given IP address leased. It - // returns nil if there is no such client, due to an assumption that a DHCP - // client must always have a MAC address. - MACByIP(ip netip.Addr) (mac net.HardwareAddr) -} - // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { // storage stores information about persistent clients. storage *client.Storage - // dhcp is the DHCP service implementation. - dhcp DHCP - // clientChecker checks if a client is blocked by the current access // settings. clientChecker BlockedClientChecker - // etcHosts contains list of rewrite rules taken from the operating system's - // hosts database. - etcHosts *aghnet.HostsContainer - - // arpDB stores the neighbors retrieved from ARP. - arpDB arpdb.Interface - // lock protects all fields. // // TODO(a.garipov): Use a pointer and describe which fields are protected in @@ -92,7 +62,7 @@ type BlockedClientChecker interface { // Note: this function must be called only once func (clients *clientsContainer) Init( objects []*clientObject, - dhcpServer DHCP, + dhcpServer client.DHCP, etcHosts *aghnet.HostsContainer, arpDB arpdb.Interface, filteringConf *filtering.Config, @@ -102,26 +72,15 @@ func (clients *clientsContainer) Init( return errors.Error("clients container already initialized") } - clients.storage = client.NewStorage(&client.Config{ - AllowedTags: clientTags, - }) + confClients := make([]*client.Persistent, 0, len(objects)) + for i, o := range objects { + var p *client.Persistent + p, err = o.toPersistent(filteringConf) + if err != nil { + return fmt.Errorf("init persistent client at index %d: %w", i, err) + } - // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. - clients.dhcp = dhcpServer - - clients.etcHosts = etcHosts - clients.arpDB = arpDB - err = clients.addFromConfig(objects, filteringConf) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize - clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) - - if clients.testing { - return nil + confClients = append(confClients, p) } // The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile @@ -130,21 +89,26 @@ func (clients *clientsContainer) Init( // TODO(e.burkov): The option should probably be returned, since hosts file // currently used not only for clients' information enrichment, but also in // the filtering module and upstream addresses resolution. - if config.Clients.Sources.HostsFile && clients.etcHosts != nil { - go clients.handleHostsUpdates() + var hosts client.HostsContainer = etcHosts + if !config.Clients.Sources.HostsFile { + hosts = nil + } + + clients.storage, err = client.NewStorage(&client.StorageConfig{ + InitialClients: confClients, + DHCP: dhcpServer, + EtcHosts: hosts, + ARPDB: arpDB, + ARPClientsUpdatePeriod: arpClientsUpdatePeriod, + RuntimeSourceDHCP: config.Clients.Sources.DHCP, + }) + if err != nil { + return fmt.Errorf("init client storage: %w", err) } return nil } -// handleHostsUpdates receives the updates from the hosts container and adds -// them to the clients container. It is intended to be used as a goroutine. -func (clients *clientsContainer) handleHostsUpdates() { - for upd := range clients.etcHosts.Upd() { - clients.addFromHostsFile(upd) - } -} - // webHandlersRegistered prevents a [clientsContainer] from registering its web // handlers more than once. // @@ -152,7 +116,7 @@ func (clients *clientsContainer) handleHostsUpdates() { var webHandlersRegistered = false // Start starts the clients container. -func (clients *clientsContainer) Start() { +func (clients *clientsContainer) Start(ctx context.Context) (err error) { if clients.testing { return } @@ -162,14 +126,7 @@ func (clients *clientsContainer) Start() { clients.registerWebHandlers() } - go clients.periodicUpdate() -} - -// reloadARP reloads runtime clients from ARP, if configured. -func (clients *clientsContainer) reloadARP() { - if clients.arpDB != nil { - clients.addFromSystemARP() - } + return clients.storage.Start(ctx) } // clientObject is the YAML representation of a persistent client. @@ -270,28 +227,6 @@ func (o *clientObject) toPersistent( return cli, nil } -// addFromConfig initializes the clients container with objects from the -// configuration file. -func (clients *clientsContainer) addFromConfig( - objects []*clientObject, - filteringConf *filtering.Config, -) (err error) { - for i, o := range objects { - var cli *client.Persistent - cli, err = o.toPersistent(filteringConf) - if err != nil { - return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) - } - - err = clients.storage.Add(cli) - if err != nil { - return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err) - } - } - - return nil -} - // forConfig returns all currently known persistent clients as objects for the // configuration file. func (clients *clientsContainer) forConfig() (objs []*clientObject) { @@ -332,39 +267,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { // arpClientsUpdatePeriod defines how often ARP clients are updated. const arpClientsUpdatePeriod = 10 * time.Minute -func (clients *clientsContainer) periodicUpdate() { - defer log.OnPanic("clients container") - - for { - clients.reloadARP() - time.Sleep(arpClientsUpdatePeriod) - } -} - -// clientSource checks if client with this IP address already exists and returns -// the source which updated it last. It returns [client.SourceNone] if the -// client doesn't exist. Note that it is only used in tests. -func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) { - clients.lock.Lock() - defer clients.lock.Unlock() - - _, ok := clients.findLocked(ip.String()) - if ok { - return client.SourcePersistent - } - - rc := clients.storage.ClientRuntime(ip) - if rc != nil { - src, _ = rc.Info() - } - - if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" { - src = client.SourceDHCP - } - - return src -} - // findMultiple is a wrapper around [clientsContainer.find] to make it a valid // client finder for the query log. c is never nil; if no information about the // client is found, it returns an artificial client record by only setting the @@ -410,7 +312,7 @@ func (clients *clientsContainer) clientOrArtificial( }, false } - rc := clients.findRuntimeClient(ip) + rc := clients.storage.ClientRuntime(ip) if rc != nil { _, host := rc.Info() @@ -425,19 +327,6 @@ func (clients *clientsContainer) clientOrArtificial( }, true } -// find returns a shallow copy of the client if there is one found. -func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) { - clients.lock.Lock() - defer clients.lock.Unlock() - - c, ok = clients.findLocked(id) - if !ok { - return nil, false - } - - return c, true -} - // shouldCountClient is a wrapper around [clientsContainer.find] to make it a // valid client information finder for the statistics. If no information about // the client is found, it returns true. @@ -446,7 +335,7 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { defer clients.lock.Unlock() for _, id := range ids { - client, ok := clients.findLocked(id) + client, ok := clients.storage.Find(id) if ok { return !client.IgnoreStatistics } @@ -468,7 +357,7 @@ func (clients *clientsContainer) UpstreamConfigByID( clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findLocked(id) + c, ok := clients.storage.Find(id) if !ok { return nil, nil } else if c.UpstreamConfig != nil { @@ -506,198 +395,17 @@ func (clients *clientsContainer) UpstreamConfigByID( return conf, nil } -// findLocked searches for a client by its ID. clients.lock is expected to be -// locked. -func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) { - c, ok = clients.storage.Find(id) - if ok { - return c, true - } - - ip, err := netip.ParseAddr(id) - if err != nil { - return nil, false - } - - // TODO(e.burkov): Iterate through clients.list only once. - return clients.findDHCP(ip) -} - -// findDHCP searches for a client by its MAC, if the DHCP server is active and -// there is such client. clients.lock is expected to be locked. -func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) { - foundMAC := clients.dhcp.MACByIP(ip) - if foundMAC == nil { - return nil, false - } - - return clients.storage.FindByMAC(foundMAC) -} - -// findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { - rc = clients.storage.ClientRuntime(ip) - host := clients.dhcp.HostByIP(ip) - - if host != "" { - if rc == nil { - rc = client.NewRuntime(ip) - } - - rc.SetInfo(client.SourceDHCP, []string{host}) - - return rc - } - - return rc -} - -// setWHOISInfo sets the WHOIS information for a client. clients.lock is -// expected to be locked. -func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { - _, ok := clients.findLocked(ip.String()) - if ok { - log.Debug("clients: client for %s is already created, ignore whois info", ip) - - return - } - - rc := client.NewRuntime(ip) - rc.SetWHOIS(wi) - clients.storage.UpdateRuntime(rc) - - log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) -} - -// addHost adds a new IP-hostname pairing. The priorities of the sources are -// taken into account. ok is true if the pairing was added. -// -// TODO(a.garipov): Only used in internal tests. Consider removing. -func (clients *clientsContainer) addHost( - ip netip.Addr, - host string, - src client.Source, -) (ok bool) { - clients.lock.Lock() - defer clients.lock.Unlock() - - return clients.addHostLocked(ip, host, src) -} - // type check var _ client.AddressUpdater = (*clientsContainer)(nil) // UpdateAddress implements the [client.AddressUpdater] interface for // *clientsContainer func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { - // Common fast path optimization. - if host == "" && info == nil { - return - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - if host != "" { - ok := clients.addHostLocked(ip, host, client.SourceRDNS) - if !ok { - log.Debug("clients: host for client %q already set with higher priority source", ip) - } - } - - if info != nil { - clients.setWHOISInfo(ip, info) - } -} - -// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be -// locked. -func (clients *clientsContainer) addHostLocked( - ip netip.Addr, - host string, - src client.Source, -) (ok bool) { - rc := client.NewRuntime(ip) - rc.SetInfo(src, []string{host}) - - if config.Clients.Sources.DHCP { - if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" { - rc.SetInfo(client.SourceDHCP, []string{dhcpHost}) - } - } - - clients.storage.UpdateRuntime(rc) - - log.Debug( - "clients: adding client info %s -> %q %q [%d]", - ip, - src, - host, - clients.storage.SizeRuntime(), - ) - - return true -} - -// addFromHostsFile fills the client-hostname pairing index from the system's -// hosts files. -func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { - clients.lock.Lock() - defer clients.lock.Unlock() - - var rcs []*client.Runtime - hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { - // Only the first name of the first record is considered a canonical - // hostname for the IP address. - // - // TODO(e.burkov): Consider using all the names from all the records. - rc := client.NewRuntime(addr) - rc.SetInfo(client.SourceHostsFile, []string{names[0]}) - - rcs = append(rcs, rc) - - return true - }) - - added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs) - log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed) -} - -// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a -// command. -func (clients *clientsContainer) addFromSystemARP() { - if err := clients.arpDB.Refresh(); err != nil { - log.Error("refreshing arp container: %s", err) - - clients.arpDB = arpdb.Empty{} - - return - } - - ns := clients.arpDB.Neighbors() - if len(ns) == 0 { - log.Debug("refreshing arp container: the update is empty") - - return - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - var rcs []*client.Runtime - for _, n := range ns { - rc := client.NewRuntime(n.IP) - rc.SetInfo(client.SourceARP, []string{n.Name}) - - rcs = append(rcs, rc) - } - - added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs) - log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed) + clients.storage.UpdateAddress(ip, host, info) } // close gracefully closes all the client-specific upstream configurations of // the persistent clients. -func (clients *clientsContainer) close() (err error) { - return clients.storage.CloseUpstreams() +func (clients *clientsContainer) close(ctx context.Context) (err error) { + return clients.storage.Shutdown(ctx) } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index f47676f0..c23f4b23 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -3,34 +3,14 @@ package home import ( "net" "net/netip" - "runtime" "testing" - "time" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type testDHCP struct { - OnLeases func() (leases []*dhcpsvc.Lease) - OnHostBy func(ip netip.Addr) (host string) - OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) -} - -// Lease implements the [DHCP] interface for testDHCP. -func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } - -// HostByIP implements the [DHCP] interface for testDHCP. -func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } - -// MACByIP implements the [DHCP] interface for testDHCP. -func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } - // newClientsContainer is a helper that creates a new clients container for // tests. func newClientsContainer(t *testing.T) (c *clientsContainer) { @@ -40,316 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { testing: true, } - dhcp := &testDHCP{ - OnLeases: func() (leases []*dhcpsvc.Lease) { return nil }, - OnHostBy: func(ip netip.Addr) (host string) { return "" }, - OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, - } - - require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{})) + require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{})) return c } -func TestClients(t *testing.T) { - clients := newClientsContainer(t) - - t.Run("add_success", func(t *testing.T) { - var ( - cliNone = "1.2.3.4" - cli1 = "1.1.1.1" - cli2 = "2.2.2.2" - - cli1IP = netip.MustParseAddr(cli1) - cli2IP = netip.MustParseAddr(cli2) - - cliIPv6 = netip.MustParseAddr("1:2:3::4") - ) - - c := &client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{cli1IP, cliIPv6}, - } - - err := clients.storage.Add(c) - require.NoError(t, err) - - c = &client.Persistent{ - Name: "client2", - UID: client.MustNewUID(), - IPs: []netip.Addr{cli2IP}, - } - - err = clients.storage.Add(c) - require.NoError(t, err) - - c, ok := clients.find(cli1) - require.True(t, ok) - - assert.Equal(t, "client1", c.Name) - - c, ok = clients.find("1:2:3::4") - require.True(t, ok) - - assert.Equal(t, "client1", c.Name) - - c, ok = clients.find(cli2) - require.True(t, ok) - - assert.Equal(t, "client2", c.Name) - - _, ok = clients.find(cliNone) - assert.False(t, ok) - - assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent) - assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent) - }) - - t.Run("add_fail_name", func(t *testing.T) { - err := clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, - }) - require.Error(t, err) - }) - - t.Run("add_fail_ip", func(t *testing.T) { - err := clients.storage.Add(&client.Persistent{ - Name: "client3", - UID: client.MustNewUID(), - }) - require.Error(t, err) - }) - - t.Run("update_fail_ip", func(t *testing.T) { - err := clients.storage.Update("client1", &client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - }) - assert.Error(t, err) - }) - - t.Run("update_success", func(t *testing.T) { - var ( - cliOld = "1.1.1.1" - cliNew = "1.1.1.2" - - cliNewIP = netip.MustParseAddr(cliNew) - ) - - prev, ok := clients.storage.FindByName("client1") - require.True(t, ok) - require.NotNil(t, prev) - - err := clients.storage.Update("client1", &client.Persistent{ - Name: "client1", - UID: prev.UID, - IPs: []netip.Addr{cliNewIP}, - }) - require.NoError(t, err) - - _, ok = clients.find(cliOld) - assert.False(t, ok) - - assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - - prev, ok = clients.storage.FindByName("client1") - require.True(t, ok) - require.NotNil(t, prev) - - err = clients.storage.Update("client1", &client.Persistent{ - Name: "client1-renamed", - UID: prev.UID, - IPs: []netip.Addr{cliNewIP}, - UseOwnSettings: true, - }) - require.NoError(t, err) - - c, ok := clients.find(cliNew) - require.True(t, ok) - - assert.Equal(t, "client1-renamed", c.Name) - assert.True(t, c.UseOwnSettings) - - nilCli, ok := clients.storage.FindByName("client1") - require.False(t, ok) - - assert.Nil(t, nilCli) - - require.Len(t, c.IDs(), 1) - - assert.Equal(t, cliNewIP, c.IPs[0]) - }) - - t.Run("del_success", func(t *testing.T) { - ok := clients.storage.RemoveByName("client1-renamed") - require.True(t, ok) - - _, ok = clients.find("1.1.1.2") - assert.False(t, ok) - }) - - t.Run("del_fail", func(t *testing.T) { - ok := clients.storage.RemoveByName("client3") - assert.False(t, ok) - }) - - t.Run("addhost_success", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host", client.SourceARP) - assert.True(t, ok) - - ok = clients.addHost(ip, "host2", client.SourceARP) - assert.True(t, ok) - - ok = clients.addHost(ip, "host3", client.SourceHostsFile) - assert.True(t, ok) - - assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile) - }) - - t.Run("dhcp_replaces_arp", func(t *testing.T) { - ip := netip.MustParseAddr("1.2.3.4") - ok := clients.addHost(ip, "from_arp", client.SourceARP) - assert.True(t, ok) - assert.Equal(t, clients.clientSource(ip), client.SourceARP) - - ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP) - assert.True(t, ok) - assert.Equal(t, clients.clientSource(ip), client.SourceDHCP) - }) - - t.Run("addhost_priority", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host1", client.SourceRDNS) - assert.True(t, ok) - - assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip)) - }) -} - -func TestClientsWHOIS(t *testing.T) { - clients := newClientsContainer(t) - whois := &whois.Info{ - Country: "AU", - Orgname: "Example Org", - } - - t.Run("new_client", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.255") - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.NotNil(t, rc) - - assert.Equal(t, whois, rc.WHOIS()) - }) - - t.Run("existing_auto-client", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - ok := clients.addHost(ip, "host", client.SourceRDNS) - assert.True(t, ok) - - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.NotNil(t, rc) - - assert.Equal(t, whois, rc.WHOIS()) - }) - - t.Run("can't_set_manually-added", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.2") - - err := clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, - }) - require.NoError(t, err) - - clients.setWHOISInfo(ip, whois) - rc := clients.storage.ClientRuntime(ip) - require.Nil(t, rc) - - assert.True(t, clients.storage.RemoveByName("client1")) - }) -} - -func TestClientsAddExisting(t *testing.T) { - clients := newClientsContainer(t) - - t.Run("simple", func(t *testing.T) { - ip := netip.MustParseAddr("1.1.1.1") - - // Add a client. - err := clients.storage.Add(&client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, - Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, - MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, - }) - require.NoError(t, err) - - // Now add an auto-client with the same IP. - ok := clients.addHost(ip, "test", client.SourceRDNS) - assert.True(t, ok) - }) - - t.Run("complicated", func(t *testing.T) { - // TODO(a.garipov): Properly decouple the DHCP server from the client - // storage. - if runtime.GOOS == "windows" { - t.Skip("skipping dhcp test on windows") - } - - ip := netip.MustParseAddr("1.2.3.4") - - // First, init a DHCP server with a single static lease. - config := &dhcpd.ServerConfig{ - Enabled: true, - DataDir: t.TempDir(), - Conf4: dhcpd.V4ServerConf{ - Enabled: true, - GatewayIP: netip.MustParseAddr("1.2.3.1"), - SubnetMask: netip.MustParseAddr("255.255.255.0"), - RangeStart: netip.MustParseAddr("1.2.3.2"), - RangeEnd: netip.MustParseAddr("1.2.3.10"), - }, - } - - dhcpServer, err := dhcpd.Create(config) - require.NoError(t, err) - - clients.dhcp = dhcpServer - - err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ - HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, - IP: ip, - Hostname: "testhost", - Expiry: time.Now().Add(time.Hour), - }) - require.NoError(t, err) - - // Add a new client with the same IP as for a client with MAC. - err = clients.storage.Add(&client.Persistent{ - Name: "client2", - UID: client.MustNewUID(), - IPs: []netip.Addr{ip}, - }) - require.NoError(t, err) - - // Add a new client with the IP from the first client's IP range. - err = clients.storage.Add(&client.Persistent{ - Name: "client3", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, - }) - require.NoError(t, err) - }) -} - func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index f5f061e4..73259d29 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -103,6 +103,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) + clients.storage.UpdateDHCP() + clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ @@ -117,20 +119,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) - if config.Clients.Sources.DHCP { - for _, l := range clients.dhcp.Leases() { - cj := runtimeClientJSON{ - Name: l.Hostname, - Source: client.SourceDHCP, - IP: l.IP, - WHOIS: &whois.Info{}, - } - - data.RuntimeClients = append(data.RuntimeClients, cj) - } - } - - data.Tags = clientTags + data.Tags = clients.storage.AllowedTags() aghhttp.WriteJSONResponseOK(w, r, data) } @@ -432,7 +421,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http } ip, _ := netip.ParseAddr(idStr) - c, ok := clients.find(idStr) + c, ok := clients.storage.Find(idStr) var cj *clientJSON if !ok { cj = clients.findRuntime(ip, idStr) @@ -454,7 +443,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { - rc := clients.findRuntimeClient(ip) + rc := clients.storage.ClientRuntime(ip) if rc == nil { // It is still possible that the IP used to be in the runtime clients // list, but then the server was reloaded. So, check the DNS server's diff --git a/internal/home/clientstags.go b/internal/home/clientstags.go deleted file mode 100644 index b4fc518a..00000000 --- a/internal/home/clientstags.go +++ /dev/null @@ -1,27 +0,0 @@ -package home - -var clientTags = []string{ - "device_audio", - "device_camera", - "device_gameconsole", - "device_laptop", - "device_nas", // Network-attached Storage - "device_other", - "device_pc", - "device_phone", - "device_printer", - "device_securityalarm", - "device_tablet", - "device_tv", - - "os_android", - "os_ios", - "os_linux", - "os_macos", - "os_other", - "os_windows", - - "user_admin", - "user_child", - "user_regular", -} diff --git a/internal/home/dns.go b/internal/home/dns.go index 41e521ce..9dd711f5 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -1,6 +1,7 @@ package home import ( + "context" "fmt" "log/slog" "net" @@ -414,9 +415,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte setts.ClientIP = clientIP - c, ok := Context.clients.find(clientID) + c, ok := Context.clients.storage.Find(clientID) if !ok { - c, ok = Context.clients.find(clientIP.String()) + c, ok = Context.clients.storage.Find(clientIP.String()) if !ok { log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) @@ -459,11 +460,15 @@ func startDNSServer() error { Context.filters.EnableFilters(false) - Context.clients.Start() - - err := Context.dnsServer.Start() + // TODO(s.chzhen): Pass context. + err := Context.clients.Start(context.TODO()) if err != nil { - return fmt.Errorf("couldn't start forwarding DNS server: %w", err) + return fmt.Errorf("starting clients container: %w", err) + } + + err = Context.dnsServer.Start() + if err != nil { + return fmt.Errorf("starting dns server: %w", err) } Context.filters.Start() @@ -500,7 +505,7 @@ func stopDNSServer() (err error) { return fmt.Errorf("stopping forwarding dns server: %w", err) } - err = Context.clients.close() + err = Context.clients.close(context.TODO()) if err != nil { return fmt.Errorf("closing clients container: %w", err) } diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 4adaec81..d3712890 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -18,9 +18,8 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage(&client.Config{ - AllowedTags: nil, - }) + s, err := client.NewStorage(&client.StorageConfig{}) + require.NoError(tb, err) for _, p := range clients { p.UID = client.MustNewUID() diff --git a/internal/home/home.go b/internal/home/home.go index 8c343dbc..69a67223 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -119,7 +119,7 @@ func Main(clientBuildFS fs.FS) { log.Info("Received signal %q", sig) switch sig { case syscall.SIGHUP: - Context.clients.reloadARP() + Context.clients.storage.ReloadARP() Context.tls.reload() default: cleanup(context.Background())