diff --git a/internal/client/storage.go b/internal/client/storage.go index 23bb6ca8..8684fc28 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -1,23 +1,62 @@ package client import ( + "cmp" "fmt" "net" "net/netip" "slices" "sync" + "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/arpdb" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" ) +// DHCP is an interface for accessing DHCP lease data the [Storage] needs. +type DHCP interface { + // Leases returns all the DHCP leases. + Leases() (leases []*dhcpsvc.Lease) + + // HostByIP returns the hostname of the DHCP client with the given IP + // address. The address will be netip.Addr{} if there is no such client, + // due to an assumption that a DHCP client must always have a hostname. + HostByIP(ip netip.Addr) (host string) + + // MACByIP returns the MAC address for the given IP address leased. It + // returns nil if there is no such client, due to an assumption that a DHCP + // client must always have a MAC address. + MACByIP(ip netip.Addr) (mac net.HardwareAddr) +} + +type emptyDHCP struct{} + +// type check +var _ DHCP = emptyDHCP{} + +func (emptyDHCP) Leases() (_ []*dhcpsvc.Lease) { return nil } + +func (emptyDHCP) HostByIP(_ netip.Addr) (_ string) { return "" } + +func (emptyDHCP) MACByIP(_ netip.Addr) (_ net.HardwareAddr) { return nil } + // Config is the client storage configuration structure. -// -// TODO(s.chzhen): Expand. type Config struct { + DHCP DHCP + EtcHosts *aghnet.HostsContainer + ARPDB arpdb.Interface + // AllowedTags is a list of all allowed client tags. AllowedTags []string + + InitialClients []*Persistent + ARPClientsUpdatePeriod time.Duration } // Storage contains information about persistent and runtime clients. @@ -33,18 +72,156 @@ type Storage struct { // runtimeIndex contains information about runtime clients. runtimeIndex *RuntimeIndex + + dhcp DHCP + etcHosts *aghnet.HostsContainer + arpDB arpdb.Interface + arpClientsUpdatePeriod time.Duration } // NewStorage returns initialized client storage. conf must not be nil. -func NewStorage(conf *Config) (s *Storage) { +func NewStorage(conf *Config) (s *Storage, err error) { allowedTags := container.NewMapSet(conf.AllowedTags...) - - return &Storage{ - allowedTags: allowedTags, - mu: &sync.Mutex{}, - index: newIndex(), - runtimeIndex: NewRuntimeIndex(), + s = &Storage{ + allowedTags: allowedTags, + mu: &sync.Mutex{}, + index: newIndex(), + runtimeIndex: NewRuntimeIndex(), + dhcp: cmp.Or(conf.DHCP, DHCP(emptyDHCP{})), + etcHosts: conf.EtcHosts, + arpDB: conf.ARPDB, + arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod, } + + for i, p := range conf.InitialClients { + err = s.Add(p) + if err != nil { + return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err) + } + } + + return s, nil +} + +// Start starts the goroutines for updating the runtime client information. +func (s *Storage) Start() { + go s.periodicARPUpdate() + go s.handleHostsUpdates() +} + +// periodicARPUpdate periodically reloads runtime clients from ARP. It is +// intended to be used as a goroutine. +func (s *Storage) periodicARPUpdate() { + defer log.OnPanic("storage") + + for { + s.ReloadARP() + time.Sleep(s.arpClientsUpdatePeriod) + } +} + +// ReloadARP reloads runtime clients from ARP, if configured. +func (s *Storage) ReloadARP() { + if s.arpDB != nil { + s.addFromSystemARP() + } +} + +// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a +// command. +func (s *Storage) addFromSystemARP() { + if err := s.arpDB.Refresh(); err != nil { + s.arpDB = arpdb.Empty{} + log.Error("refreshing arp container: %s", err) + + return + } + + ns := s.arpDB.Neighbors() + if len(ns) == 0 { + log.Debug("refreshing arp container: the update is empty") + + return + } + + var rcs []*Runtime + for _, n := range ns { + rc := NewRuntime(n.IP) + rc.SetInfo(SourceARP, []string{n.Name}) + + rcs = append(rcs, rc) + } + + added, removed := s.BatchUpdateBySource(SourceARP, rcs) + log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", added, removed) +} + +// handleHostsUpdates receives the updates from the hosts container and adds +// them to the clients storage. It is intended to be used as a goroutine. +func (s *Storage) handleHostsUpdates() { + defer log.OnPanic("storage") + + for upd := range s.etcHosts.Upd() { + s.addFromHostsFile(upd) + } +} + +// addFromHostsFile fills the client-hostname pairing index from the system's +// hosts files. +func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) { + var rcs []*Runtime + hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { + // Only the first name of the first record is considered a canonical + // hostname for the IP address. + // + // TODO(e.burkov): Consider using all the names from all the records. + rc := NewRuntime(addr) + rc.SetInfo(SourceHostsFile, []string{names[0]}) + + rcs = append(rcs, rc) + + return true + }) + + added, removed := s.BatchUpdateBySource(SourceHostsFile, rcs) + log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed) +} + +// type check +var _ AddressUpdater = (*Storage)(nil) + +// UpdateAddress implements the [AddressUpdater] interface for *Storage +func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { + // Common fast path optimization. + if host == "" && info == nil { + return + } + + if host != "" { + rc := NewRuntime(ip) + rc.SetInfo(SourceRDNS, []string{host}) + s.UpdateRuntime(rc) + } + + if info != nil { + s.setWHOISInfo(ip, info) + } +} + +// setWHOISInfo sets the WHOIS information for a runtime client. +func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) { + _, ok := s.Find(ip.String()) + if ok { + log.Debug("storage: client for %s is already created, ignore whois info", ip) + + return + } + + rc := NewRuntime(ip) + rc.SetWHOIS(wi) + s.UpdateRuntime(rc) + + log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi) } // Add stores persistent client information or returns an error. @@ -103,6 +280,16 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) { return p.ShallowClone(), ok } + ip, err := netip.ParseAddr(id) + if err != nil { + return nil, false + } + + foundMAC := s.dhcp.MACByIP(ip) + if foundMAC != nil { + return s.FindByMAC(foundMAC) + } + return nil, false } @@ -130,11 +317,9 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { return nil, false } -// FindByMAC finds persistent client by MAC and returns its shallow copy. +// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu +// is expected to be locked. func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) { - s.mu.Lock() - defer s.mu.Unlock() - p, ok = s.index.findByMAC(mac) if ok { return p.ShallowClone(), ok @@ -226,13 +411,25 @@ func (s *Storage) CloseUpstreams() (err error) { // ClientRuntime returns a copy of the saved runtime client by ip. If no such // client exists, returns nil. -// -// TODO(s.chzhen): Use it. func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() - return s.runtimeIndex.Client(ip) + rc = s.runtimeIndex.Client(ip) + if rc != nil { + return rc + } + + host := s.dhcp.HostByIP(ip) + if host == "" { + return nil + } + + rc = NewRuntime(ip) + rc.SetInfo(SourceDHCP, []string{host}) + s.UpdateRuntime(rc) + + return rc } // UpdateRuntime updates the stored runtime client with information from rc. If diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 5ac02747..34a35636 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -17,9 +17,10 @@ import ( func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) + require.NoError(tb, err) for _, c := range m { c.UID = client.MustNewUID() @@ -73,10 +74,12 @@ func TestStorage_Add(t *testing.T) { UID: existingClientUID, } - s := client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: []string{allowedTag}, }) - err := s.Add(existingClient) + require.NoError(t, err) + + err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { @@ -164,10 +167,12 @@ func TestStorage_RemoveByName(t *testing.T) { UID: client.MustNewUID(), } - s := client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) - err := s.Add(existingClient) + require.NoError(t, err) + + err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { @@ -191,9 +196,11 @@ func TestStorage_RemoveByName(t *testing.T) { } t.Run("duplicate_remove", func(t *testing.T) { - s = client.NewStorage(&client.Config{ + s, err = client.NewStorage(&client.Config{ AllowedTags: nil, }) + require.NoError(t, err) + err = s.Add(existingClient) require.NoError(t, err) @@ -651,9 +658,10 @@ func TestStorage_UpdateRuntime(t *testing.T) { } updated.SetWHOIS(info) - s := client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) + require.NoError(t, err) t.Run("add_arp_client", func(t *testing.T) { added := client.NewRuntime(ip) @@ -739,9 +747,10 @@ func TestStorage_BatchUpdateBySource(t *testing.T) { newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5), } - s := client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) + require.NoError(t, err) t.Run("populate_storage_with_first_clients", func(t *testing.T) { added, removed := s.BatchUpdateBySource(defSrc, firstClients) diff --git a/internal/home/clients.go b/internal/home/clients.go index 819564bc..0952991c 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -2,7 +2,6 @@ package home import ( "fmt" - "net" "net/netip" "slices" "sync" @@ -11,7 +10,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -20,47 +18,21 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/hostsfile" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" ) -// DHCP is an interface for accessing DHCP lease data the [clientsContainer] -// needs. -type DHCP interface { - // Leases returns all the DHCP leases. - Leases() (leases []*dhcpsvc.Lease) - - // HostByIP returns the hostname of the DHCP client with the given IP - // address. The address will be netip.Addr{} if there is no such client, - // due to an assumption that a DHCP client must always have a hostname. - HostByIP(ip netip.Addr) (host string) - - // MACByIP returns the MAC address for the given IP address leased. It - // returns nil if there is no such client, due to an assumption that a DHCP - // client must always have a MAC address. - MACByIP(ip netip.Addr) (mac net.HardwareAddr) -} - // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { // storage stores information about persistent clients. storage *client.Storage // dhcp is the DHCP service implementation. - dhcp DHCP + dhcp client.DHCP // clientChecker checks if a client is blocked by the current access // settings. clientChecker BlockedClientChecker - // etcHosts contains list of rewrite rules taken from the operating system's - // hosts database. - etcHosts *aghnet.HostsContainer - - // arpDB stores the neighbors retrieved from ARP. - arpDB arpdb.Interface - // lock protects all fields. // // TODO(a.garipov): Use a pointer and describe which fields are protected in @@ -92,7 +64,7 @@ type BlockedClientChecker interface { // Note: this function must be called only once func (clients *clientsContainer) Init( objects []*clientObject, - dhcpServer DHCP, + dhcpServer client.DHCP, etcHosts *aghnet.HostsContainer, arpDB arpdb.Interface, filteringConf *filtering.Config, @@ -102,49 +74,35 @@ func (clients *clientsContainer) Init( return errors.Error("clients container already initialized") } - clients.storage = client.NewStorage(&client.Config{ - AllowedTags: clientTags, - }) - // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. clients.dhcp = dhcpServer - clients.etcHosts = etcHosts - clients.arpDB = arpDB - err = clients.addFromConfig(objects, filteringConf) + confClients := make([]*client.Persistent, 0, len(objects)) + for i, o := range objects { + var p *client.Persistent + p, err = o.toPersistent(filteringConf) + if err != nil { + return fmt.Errorf("init persistent client at index %d: %w", i, err) + } + + confClients = append(confClients, p) + } + + clients.storage, err = client.NewStorage(&client.Config{ + AllowedTags: clientTags, + InitialClients: confClients, + DHCP: dhcpServer, + EtcHosts: etcHosts, + ARPDB: arpDB, + ARPClientsUpdatePeriod: arpClientsUpdatePeriod, + }) if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize - clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) - - if clients.testing { - return nil - } - - // The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile - // is true, because of the deprecated option --no-etc-hosts. - // - // TODO(e.burkov): The option should probably be returned, since hosts file - // currently used not only for clients' information enrichment, but also in - // the filtering module and upstream addresses resolution. - if config.Clients.Sources.HostsFile && clients.etcHosts != nil { - go clients.handleHostsUpdates() + return fmt.Errorf("init client storage: %w", err) } return nil } -// handleHostsUpdates receives the updates from the hosts container and adds -// them to the clients container. It is intended to be used as a goroutine. -func (clients *clientsContainer) handleHostsUpdates() { - for upd := range clients.etcHosts.Upd() { - clients.addFromHostsFile(upd) - } -} - // webHandlersRegistered prevents a [clientsContainer] from registering its web // handlers more than once. // @@ -162,14 +120,7 @@ func (clients *clientsContainer) Start() { clients.registerWebHandlers() } - go clients.periodicUpdate() -} - -// reloadARP reloads runtime clients from ARP, if configured. -func (clients *clientsContainer) reloadARP() { - if clients.arpDB != nil { - clients.addFromSystemARP() - } + clients.storage.Start() } // clientObject is the YAML representation of a persistent client. @@ -270,28 +221,6 @@ func (o *clientObject) toPersistent( return cli, nil } -// addFromConfig initializes the clients container with objects from the -// configuration file. -func (clients *clientsContainer) addFromConfig( - objects []*clientObject, - filteringConf *filtering.Config, -) (err error) { - for i, o := range objects { - var cli *client.Persistent - cli, err = o.toPersistent(filteringConf) - if err != nil { - return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) - } - - err = clients.storage.Add(cli) - if err != nil { - return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err) - } - } - - return nil -} - // forConfig returns all currently known persistent clients as objects for the // configuration file. func (clients *clientsContainer) forConfig() (objs []*clientObject) { @@ -332,39 +261,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { // arpClientsUpdatePeriod defines how often ARP clients are updated. const arpClientsUpdatePeriod = 10 * time.Minute -func (clients *clientsContainer) periodicUpdate() { - defer log.OnPanic("clients container") - - for { - clients.reloadARP() - time.Sleep(arpClientsUpdatePeriod) - } -} - -// clientSource checks if client with this IP address already exists and returns -// the source which updated it last. It returns [client.SourceNone] if the -// client doesn't exist. Note that it is only used in tests. -func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) { - clients.lock.Lock() - defer clients.lock.Unlock() - - _, ok := clients.findLocked(ip.String()) - if ok { - return client.SourcePersistent - } - - rc := clients.storage.ClientRuntime(ip) - if rc != nil { - src, _ = rc.Info() - } - - if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" { - src = client.SourceDHCP - } - - return src -} - // findMultiple is a wrapper around [clientsContainer.find] to make it a valid // client finder for the query log. c is never nil; if no information about the // client is found, it returns an artificial client record by only setting the @@ -410,7 +306,7 @@ func (clients *clientsContainer) clientOrArtificial( }, false } - rc := clients.findRuntimeClient(ip) + rc := clients.storage.ClientRuntime(ip) if rc != nil { _, host := rc.Info() @@ -425,19 +321,6 @@ func (clients *clientsContainer) clientOrArtificial( }, true } -// find returns a shallow copy of the client if there is one found. -func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) { - clients.lock.Lock() - defer clients.lock.Unlock() - - c, ok = clients.findLocked(id) - if !ok { - return nil, false - } - - return c, true -} - // shouldCountClient is a wrapper around [clientsContainer.find] to make it a // valid client information finder for the statistics. If no information about // the client is found, it returns true. @@ -446,7 +329,7 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { defer clients.lock.Unlock() for _, id := range ids { - client, ok := clients.findLocked(id) + client, ok := clients.storage.Find(id) if ok { return !client.IgnoreStatistics } @@ -468,7 +351,7 @@ func (clients *clientsContainer) UpstreamConfigByID( clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findLocked(id) + c, ok := clients.storage.Find(id) if !ok { return nil, nil } else if c.UpstreamConfig != nil { @@ -506,191 +389,13 @@ func (clients *clientsContainer) UpstreamConfigByID( return conf, nil } -// findLocked searches for a client by its ID. clients.lock is expected to be -// locked. -func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) { - c, ok = clients.storage.Find(id) - if ok { - return c, true - } - - ip, err := netip.ParseAddr(id) - if err != nil { - return nil, false - } - - // TODO(e.burkov): Iterate through clients.list only once. - return clients.findDHCP(ip) -} - -// findDHCP searches for a client by its MAC, if the DHCP server is active and -// there is such client. clients.lock is expected to be locked. -func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) { - foundMAC := clients.dhcp.MACByIP(ip) - if foundMAC == nil { - return nil, false - } - - return clients.storage.FindByMAC(foundMAC) -} - -// findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { - rc = clients.storage.ClientRuntime(ip) - host := clients.dhcp.HostByIP(ip) - - if host != "" { - if rc == nil { - rc = client.NewRuntime(ip) - } - - rc.SetInfo(client.SourceDHCP, []string{host}) - - return rc - } - - return rc -} - -// setWHOISInfo sets the WHOIS information for a client. clients.lock is -// expected to be locked. -func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { - _, ok := clients.findLocked(ip.String()) - if ok { - log.Debug("clients: client for %s is already created, ignore whois info", ip) - - return - } - - rc := client.NewRuntime(ip) - rc.SetWHOIS(wi) - clients.storage.UpdateRuntime(rc) - - log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) -} - -// addHost adds a new IP-hostname pairing. The priorities of the sources are -// taken into account. ok is true if the pairing was added. -// -// TODO(a.garipov): Only used in internal tests. Consider removing. -func (clients *clientsContainer) addHost( - ip netip.Addr, - host string, - src client.Source, -) (ok bool) { - clients.lock.Lock() - defer clients.lock.Unlock() - - return clients.addHostLocked(ip, host, src) -} - // type check var _ client.AddressUpdater = (*clientsContainer)(nil) // UpdateAddress implements the [client.AddressUpdater] interface for // *clientsContainer func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { - // Common fast path optimization. - if host == "" && info == nil { - return - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - if host != "" { - ok := clients.addHostLocked(ip, host, client.SourceRDNS) - if !ok { - log.Debug("clients: host for client %q already set with higher priority source", ip) - } - } - - if info != nil { - clients.setWHOISInfo(ip, info) - } -} - -// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be -// locked. -func (clients *clientsContainer) addHostLocked( - ip netip.Addr, - host string, - src client.Source, -) (ok bool) { - rc := client.NewRuntime(ip) - rc.SetInfo(src, []string{host}) - if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" { - rc.SetInfo(client.SourceDHCP, []string{dhcpHost}) - } - - clients.storage.UpdateRuntime(rc) - - log.Debug( - "clients: adding client info %s -> %q %q [%d]", - ip, - src, - host, - clients.storage.SizeRuntime(), - ) - - return true -} - -// addFromHostsFile fills the client-hostname pairing index from the system's -// hosts files. -func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { - clients.lock.Lock() - defer clients.lock.Unlock() - - var rcs []*client.Runtime - hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { - // Only the first name of the first record is considered a canonical - // hostname for the IP address. - // - // TODO(e.burkov): Consider using all the names from all the records. - rc := client.NewRuntime(addr) - rc.SetInfo(client.SourceHostsFile, []string{names[0]}) - - rcs = append(rcs, rc) - - return true - }) - - added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs) - log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed) -} - -// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a -// command. -func (clients *clientsContainer) addFromSystemARP() { - if err := clients.arpDB.Refresh(); err != nil { - log.Error("refreshing arp container: %s", err) - - clients.arpDB = arpdb.Empty{} - - return - } - - ns := clients.arpDB.Neighbors() - if len(ns) == 0 { - log.Debug("refreshing arp container: the update is empty") - - return - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - var rcs []*client.Runtime - for _, n := range ns { - rc := client.NewRuntime(n.IP) - rc.SetInfo(client.SourceARP, []string{n.Name}) - - rcs = append(rcs, rc) - } - - added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs) - log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed) + clients.storage.UpdateAddress(ip, host, info) } // close gracefully closes all the client-specific upstream configurations of diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index f47676f0..1fd004c8 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -51,6 +51,47 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { return c } +// addHost adds a new IP-hostname pairing. +func (clients *clientsContainer) addHost( + ip netip.Addr, + host string, + src client.Source, +) (ok bool) { + rc := client.NewRuntime(ip) + rc.SetInfo(src, []string{host}) + clients.storage.UpdateRuntime(rc) + + return true +} + +// setWHOISInfo sets the WHOIS information for a client. +func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { + _, ok := clients.storage.Find(ip.String()) + if ok { + return + } + + rc := client.NewRuntime(ip) + rc.SetWHOIS(wi) + clients.storage.UpdateRuntime(rc) +} + +// clientSource checks if client with this IP address already exists and returns +// the highest priority client source. +func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) { + _, ok := clients.storage.Find(ip.String()) + if ok { + return client.SourcePersistent + } + + rc := clients.storage.ClientRuntime(ip) + if rc != nil { + src, _ = rc.Info() + } + + return src +} + func TestClients(t *testing.T) { clients := newClientsContainer(t) @@ -84,22 +125,22 @@ func TestClients(t *testing.T) { err = clients.storage.Add(c) require.NoError(t, err) - c, ok := clients.find(cli1) + c, ok := clients.storage.Find(cli1) require.True(t, ok) assert.Equal(t, "client1", c.Name) - c, ok = clients.find("1:2:3::4") + c, ok = clients.storage.Find("1:2:3::4") require.True(t, ok) assert.Equal(t, "client1", c.Name) - c, ok = clients.find(cli2) + c, ok = clients.storage.Find(cli2) require.True(t, ok) assert.Equal(t, "client2", c.Name) - _, ok = clients.find(cliNone) + _, ok = clients.storage.Find(cliNone) assert.False(t, ok) assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent) @@ -150,7 +191,7 @@ func TestClients(t *testing.T) { }) require.NoError(t, err) - _, ok = clients.find(cliOld) + _, ok = clients.storage.Find(cliOld) assert.False(t, ok) assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) @@ -167,7 +208,7 @@ func TestClients(t *testing.T) { }) require.NoError(t, err) - c, ok := clients.find(cliNew) + c, ok := clients.storage.Find(cliNew) require.True(t, ok) assert.Equal(t, "client1-renamed", c.Name) @@ -187,7 +228,7 @@ func TestClients(t *testing.T) { ok := clients.storage.RemoveByName("client1-renamed") require.True(t, ok) - _, ok = clients.find("1.1.1.2") + _, ok = clients.storage.Find("1.1.1.2") assert.False(t, ok) }) @@ -277,13 +318,33 @@ func TestClientsWHOIS(t *testing.T) { } func TestClientsAddExisting(t *testing.T) { - clients := newClientsContainer(t) + clients := &clientsContainer{ + testing: true, + } + + // First, init a DHCP server with a single static lease. + config := &dhcpd.ServerConfig{ + Enabled: true, + DataDir: t.TempDir(), + Conf4: dhcpd.V4ServerConf{ + Enabled: true, + GatewayIP: netip.MustParseAddr("1.2.3.1"), + SubnetMask: netip.MustParseAddr("255.255.255.0"), + RangeStart: netip.MustParseAddr("1.2.3.2"), + RangeEnd: netip.MustParseAddr("1.2.3.10"), + }, + } + + dhcpServer, err := dhcpd.Create(config) + require.NoError(t, err) + + require.NoError(t, clients.Init(nil, dhcpServer, nil, nil, &filtering.Config{})) t.Run("simple", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - err := clients.storage.Add(&client.Persistent{ + err = clients.storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -306,24 +367,6 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.2.3.4") - // First, init a DHCP server with a single static lease. - config := &dhcpd.ServerConfig{ - Enabled: true, - DataDir: t.TempDir(), - Conf4: dhcpd.V4ServerConf{ - Enabled: true, - GatewayIP: netip.MustParseAddr("1.2.3.1"), - SubnetMask: netip.MustParseAddr("255.255.255.0"), - RangeStart: netip.MustParseAddr("1.2.3.2"), - RangeEnd: netip.MustParseAddr("1.2.3.10"), - }, - } - - dhcpServer, err := dhcpd.Create(config) - require.NoError(t, err) - - clients.dhcp = dhcpServer - err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: ip, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index eba22ddb..7766ba57 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -117,6 +117,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) + // TODO(s.chzhen): Remove. for _, l := range clients.dhcp.Leases() { cj := runtimeClientJSON{ Name: l.Hostname, @@ -430,7 +431,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http } ip, _ := netip.ParseAddr(idStr) - c, ok := clients.find(idStr) + c, ok := clients.storage.Find(idStr) var cj *clientJSON if !ok { cj = clients.findRuntime(ip, idStr) @@ -452,7 +453,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { - rc := clients.findRuntimeClient(ip) + rc := clients.storage.ClientRuntime(ip) if rc == nil { // It is still possible that the IP used to be in the runtime clients // list, but then the server was reloaded. So, check the DNS server's diff --git a/internal/home/dns.go b/internal/home/dns.go index ed1f1675..44c8f93f 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -412,9 +412,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte setts.ClientIP = clientIP - c, ok := Context.clients.find(clientID) + c, ok := Context.clients.storage.Find(clientID) if !ok { - c, ok = Context.clients.find(clientIP.String()) + c, ok = Context.clients.storage.Find(clientIP.String()) if !ok { log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 4adaec81..ca112e59 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -18,9 +18,10 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage(&client.Config{ + s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) + require.NoError(tb, err) for _, p := range clients { p.UID = client.MustNewUID() diff --git a/internal/home/home.go b/internal/home/home.go index 8c343dbc..69a67223 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -119,7 +119,7 @@ func Main(clientBuildFS fs.FS) { log.Info("Received signal %q", sig) switch sig { case syscall.SIGHUP: - Context.clients.reloadARP() + Context.clients.storage.ReloadARP() Context.tls.reload() default: cleanup(context.Background())