diff --git a/internal/aghnet/arpdb.go b/internal/aghnet/arpdb.go index 0c2ac7d9..a63e5cd2 100644 --- a/internal/aghnet/arpdb.go +++ b/internal/aghnet/arpdb.go @@ -5,6 +5,7 @@ import ( "bytes" "fmt" "net" + "net/netip" "sync" "github.com/AdguardTeam/golibs/errors" @@ -54,7 +55,7 @@ type Neighbor struct { Name string // IP contains either IPv4 or IPv6. - IP net.IP + IP netip.Addr // MAC contains the hardware address. MAC net.HardwareAddr @@ -64,7 +65,7 @@ type Neighbor struct { func (n Neighbor) Clone() (clone Neighbor) { return Neighbor{ Name: n.Name, - IP: slices.Clone(n.IP), + IP: n.IP, MAC: slices.Clone(n.MAC), } } diff --git a/internal/aghnet/arpdb_bsd.go b/internal/aghnet/arpdb_bsd.go index 9519eeec..c4048939 100644 --- a/internal/aghnet/arpdb_bsd.go +++ b/internal/aghnet/arpdb_bsd.go @@ -5,6 +5,7 @@ package aghnet import ( "bufio" "net" + "net/netip" "strings" "sync" @@ -47,22 +48,28 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { if ipStr := fields[1]; len(ipStr) < 2 { continue - } else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil { + } else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil { + log.Debug("arpdb: parsing arp output: ip: %s", err) + continue } else { n.IP = ip } hwStr := fields[3] - if mac, err := net.ParseMAC(hwStr); err != nil { + mac, err := net.ParseMAC(hwStr) + if err != nil { + log.Debug("arpdb: parsing arp output: mac: %s", err) + continue } else { n.MAC = mac } host := fields[0] - if err := netutil.ValidateDomainName(host); err != nil { - log.Debug("parsing arp output: %s", err) + err = netutil.ValidateDomainName(host) + if err != nil { + log.Debug("arpdb: parsing arp output: host: %s", err) } else { n.Name = host } diff --git a/internal/aghnet/arpdb_bsd_test.go b/internal/aghnet/arpdb_bsd_test.go index 9933c721..b6bd6b9d 100644 --- a/internal/aghnet/arpdb_bsd_test.go +++ b/internal/aghnet/arpdb_bsd_test.go @@ -4,6 +4,7 @@ package aghnet import ( "net" + "net/netip" ) const arpAOutput = ` @@ -17,14 +18,14 @@ hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ var wantNeighs = []Neighbor{{ Name: "hostname.one", - IP: net.IPv4(192, 168, 1, 2), + IP: netip.MustParseAddr("192.168.1.2"), MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, }, { Name: "hostname.two", - IP: net.ParseIP("::ffff:ffff"), + IP: netip.MustParseAddr("::ffff:ffff"), MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, }, { Name: "", - IP: net.ParseIP("::1234"), + IP: netip.MustParseAddr("::1234"), MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, }} diff --git a/internal/aghnet/arpdb_linux.go b/internal/aghnet/arpdb_linux.go index 82f83adf..e6e51feb 100644 --- a/internal/aghnet/arpdb_linux.go +++ b/internal/aghnet/arpdb_linux.go @@ -7,6 +7,7 @@ import ( "fmt" "io/fs" "net" + "net/netip" "strings" "sync" @@ -94,7 +95,8 @@ func (arp *fsysARPDB) Refresh() (err error) { } n := Neighbor{} - if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() { + n.IP, err = netip.ParseAddr(fields[0]) + if err != nil || n.IP.IsUnspecified() { continue } else if n.MAC, err = net.ParseMAC(fields[3]); err != nil { continue @@ -135,15 +137,19 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { n := Neighbor{} - if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() { + ip, err := netip.ParseAddr(fields[0]) + if err != nil || n.IP.IsUnspecified() { + log.Debug("arpdb: parsing arp output: ip: %s", err) + continue } else { n.IP = ip } hwStr := fields[3] - if mac, err := net.ParseMAC(hwStr); err != nil { - log.Debug("parsing arp output: %s", err) + mac, err := net.ParseMAC(hwStr) + if err != nil { + log.Debug("arpdb: parsing arp output: mac: %s", err) continue } else { @@ -174,7 +180,9 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { if ipStr := fields[1]; len(ipStr) < 2 { continue - } else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil { + } else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil { + log.Debug("arpdb: parsing arp output: ip: %s", err) + continue } else { n.IP = ip @@ -182,7 +190,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { hwStr := fields[3] if mac, err := net.ParseMAC(hwStr); err != nil { - log.Debug("parsing arp output: %s", err) + log.Debug("arpdb: parsing arp output: mac: %s", err) continue } else { @@ -191,7 +199,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { host := fields[0] if verr := netutil.ValidateDomainName(host); verr != nil { - log.Debug("parsing arp output: %s", verr) + log.Debug("arpdb: parsing arp output: host: %s", verr) } else { n.Name = host } @@ -218,14 +226,18 @@ func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { n := Neighbor{} - if ip := net.ParseIP(fields[0]); ip == nil { + ip, err := netip.ParseAddr(fields[0]) + if err != nil { + log.Debug("arpdb: parsing arp output: ip: %s", err) + continue } else { n.IP = ip } - if mac, err := net.ParseMAC(fields[4]); err != nil { - log.Debug("parsing arp output: %s", err) + mac, err := net.ParseMAC(fields[4]) + if err != nil { + log.Debug("arpdb: parsing arp output: mac: %s", err) continue } else { diff --git a/internal/aghnet/arpdb_linux_test.go b/internal/aghnet/arpdb_linux_test.go index 22fe7135..d07c654d 100644 --- a/internal/aghnet/arpdb_linux_test.go +++ b/internal/aghnet/arpdb_linux_test.go @@ -4,6 +4,7 @@ package aghnet import ( "net" + "net/netip" "sync" "testing" "testing/fstest" @@ -33,10 +34,10 @@ const ipNeighOutput = ` ::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE` var wantNeighs = []Neighbor{{ - IP: net.IPv4(192, 168, 1, 2), + IP: netip.MustParseAddr("192.168.1.2"), MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, }, { - IP: net.ParseIP("::ffff:ffff"), + IP: netip.MustParseAddr("::ffff:ffff"), MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, }} diff --git a/internal/aghnet/arpdb_openbsd.go b/internal/aghnet/arpdb_openbsd.go index 5590f335..2b356d06 100644 --- a/internal/aghnet/arpdb_openbsd.go +++ b/internal/aghnet/arpdb_openbsd.go @@ -5,6 +5,7 @@ package aghnet import ( "bufio" "net" + "net/netip" "strings" "sync" @@ -50,14 +51,18 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { n := Neighbor{} - if ip := net.ParseIP(fields[0]); ip == nil { + ip, err := netip.ParseAddr(fields[0]) + if err != nil { + log.Debug("arpdb: parsing arp output: ip: %s", err) + continue } else { n.IP = ip } - if mac, err := net.ParseMAC(fields[1]); err != nil { - log.Debug("parsing arp output: %s", err) + mac, err := net.ParseMAC(fields[1]) + if err != nil { + log.Debug("arpdb: parsing arp output: mac: %s", err) continue } else { diff --git a/internal/aghnet/arpdb_openbsd_test.go b/internal/aghnet/arpdb_openbsd_test.go index 0a45514a..a324ed9c 100644 --- a/internal/aghnet/arpdb_openbsd_test.go +++ b/internal/aghnet/arpdb_openbsd_test.go @@ -4,6 +4,7 @@ package aghnet import ( "net" + "net/netip" ) const arpAOutput = ` @@ -15,9 +16,9 @@ Host Ethernet Address Netif Expire Flags ` var wantNeighs = []Neighbor{{ - IP: net.IPv4(192, 168, 1, 2), + IP: netip.MustParseAddr("192.168.1.2"), MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, }, { - IP: net.ParseIP("::ffff:ffff"), + IP: netip.MustParseAddr("::ffff:ffff"), MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, }} diff --git a/internal/aghnet/arpdb_test.go b/internal/aghnet/arpdb_test.go index d6971448..ab40ab6f 100644 --- a/internal/aghnet/arpdb_test.go +++ b/internal/aghnet/arpdb_test.go @@ -2,6 +2,7 @@ package aghnet import ( "net" + "net/netip" "sync" "testing" @@ -35,7 +36,7 @@ func (arp *TestARPDB) Neighbors() (ns []Neighbor) { } func TestARPDBS(t *testing.T) { - knownIP := net.IP{1, 2, 3, 4} + knownIP := netip.MustParseAddr("1.2.3.4") knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF} succRefrCount, failRefrCount := 0, 0 diff --git a/internal/aghnet/arpdb_windows.go b/internal/aghnet/arpdb_windows.go index f6e27b5b..ed4c8682 100644 --- a/internal/aghnet/arpdb_windows.go +++ b/internal/aghnet/arpdb_windows.go @@ -5,6 +5,7 @@ package aghnet import ( "bufio" "net" + "net/netip" "strings" "sync" ) @@ -43,13 +44,15 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { n := Neighbor{} - if ip := net.ParseIP(fields[0]); ip == nil { + ip, err := netip.ParseAddr(fields[0]) + if err != nil { continue } else { n.IP = ip } - if mac, err := net.ParseMAC(fields[1]); err != nil { + mac, err := net.ParseMAC(fields[1]) + if err != nil { continue } else { n.MAC = mac diff --git a/internal/aghnet/arpdb_windows_test.go b/internal/aghnet/arpdb_windows_test.go index bb75c988..c3dcfe04 100644 --- a/internal/aghnet/arpdb_windows_test.go +++ b/internal/aghnet/arpdb_windows_test.go @@ -4,6 +4,7 @@ package aghnet import ( "net" + "net/netip" ) const arpAOutput = ` @@ -14,9 +15,9 @@ Interface: 192.168.1.1 --- 0x7 ::ffff:ffff ef-cd-ab-ef-cd-ab static` var wantNeighs = []Neighbor{{ - IP: net.IPv4(192, 168, 1, 2), + IP: netip.MustParseAddr("192.168.1.2"), MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, }, { - IP: net.ParseIP("::ffff:ffff"), + IP: netip.MustParseAddr("::ffff:ffff"), MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, }} diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 46430e4e..f74ce1f9 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -5,7 +5,7 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "path" "strings" "sync" @@ -19,10 +19,9 @@ import ( "github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" + "golang.org/x/exp/maps" ) -//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap]. - // DefaultHostsPaths returns the slice of paths default for the operating system // to files and directories which are containing the hosts database. The result // is intended to be used within fs.FS so the initial slash is omitted. @@ -108,14 +107,10 @@ type HostsContainer struct { done chan struct{} // updates is the channel for receiving updated hosts. - // - // TODO(e.burkov): Use map[netip.Addr]struct{} instead. - updates chan *netutil.IPMap + updates chan HostsRecords // last is the set of hosts that was cached within last detected change. - // - // TODO(e.burkov): Use map[netip.Addr]struct{} instead. - last *netutil.IPMap + last HostsRecords // fsys is the working file system to read hosts files from. fsys fs.FS @@ -130,6 +125,25 @@ type HostsContainer struct { listID int } +// HostsRecords is a mapping of an IP address to its hosts data. +type HostsRecords map[netip.Addr]*HostsRecord + +// HostsRecord represents a single hosts file record. +type HostsRecord struct { + Aliases *stringutil.Set + Canonical string +} + +// equal returns true if all fields of rec are equal to field in other or they +// both are nil. +func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) { + if rec == nil { + return other == nil + } + + return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases) +} + // ErrNoHostsPaths is returned when there are no valid paths to watch passed to // the HostsContainer. const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided" @@ -164,7 +178,7 @@ func NewHostsContainer( }, listID: listID, done: make(chan struct{}, 1), - updates: make(chan *netutil.IPMap, 1), + updates: make(chan HostsRecords, 1), fsys: fsys, w: w, patterns: patterns, @@ -202,9 +216,8 @@ func (hc *HostsContainer) Close() (err error) { return nil } -// Upd returns the channel into which the updates are sent. The receivable -// map's values are guaranteed to be of type of *HostsRecord. -func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { +// Upd returns the channel into which the updates are sent. +func (hc *HostsContainer) Upd() (updates <-chan HostsRecords) { return hc.updates } @@ -270,7 +283,7 @@ type hostsParser struct { // table stores only the unique IP-hostname pairs. It's also sent to the // updates channel afterwards. - table *netutil.IPMap + table HostsRecords } // newHostsParser creates a new *hostsParser with buffers of size taken from the @@ -279,7 +292,7 @@ func (hc *HostsContainer) newHostsParser() (hp *hostsParser) { return &hostsParser{ rulesBuilder: &strings.Builder{}, translations: map[string]string{}, - table: netutil.NewIPMap(hc.last.Len()), + table: make(HostsRecords, len(hc.last)), } } @@ -291,7 +304,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err s := bufio.NewScanner(r) for s.Scan() { ip, hosts := hp.parseLine(s.Text()) - if ip == nil || len(hosts) == 0 { + if ip == (netip.Addr{}) || len(hosts) == 0 { continue } @@ -302,14 +315,15 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err } // parseLine parses the line having the hosts syntax ignoring invalid ones. -func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { +func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) { fields := strings.Fields(line) if len(fields) < 2 { - return nil, nil + return netip.Addr{}, nil } - if ip = net.ParseIP(fields[0]); ip == nil { - return nil, nil + ip, err := netip.ParseAddr(fields[0]) + if err != nil { + return netip.Addr{}, nil } for _, f := range fields[1:] { @@ -327,7 +341,7 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { // See https://github.com/AdguardTeam/AdGuardHome/issues/3946. // // TODO(e.burkov): Investigate if hosts may contain DNS-SD domains. - err := netutil.ValidateDomainName(f) + err = netutil.ValidateDomainName(f) if err != nil { log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f) @@ -340,30 +354,13 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { return ip, hosts } -// HostsRecord represents a single hosts file record. -type HostsRecord struct { - Aliases *stringutil.Set - Canonical string -} - -// Equal returns true if all fields of rec are equal to field in other or they -// both are nil. -func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) { - if rec == nil { - return other == nil - } - - return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases) -} - // addRecord puts the record for the IP address to the rules builder if needed. // The first host is considered to be the canonical name for the IP address. // hosts must have at least one name. -func (hp *hostsParser) addRecord(ip net.IP, hosts []string) { +func (hp *hostsParser) addRecord(ip netip.Addr, hosts []string) { line := strings.Join(append([]string{ip.String()}, hosts...), " ") - var rec *HostsRecord - v, ok := hp.table.Get(ip) + rec, ok := hp.table[ip] if !ok { rec = &HostsRecord{ Aliases: stringutil.NewSet(), @@ -371,14 +368,7 @@ func (hp *hostsParser) addRecord(ip net.IP, hosts []string) { rec.Canonical, hosts = hosts[0], hosts[1:] hp.addRules(ip, rec.Canonical, line) - hp.table.Set(ip, rec) - } else { - rec, ok = v.(*HostsRecord) - if !ok { - log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v) - - return - } + hp.table[ip] = rec } for _, host := range hosts { @@ -393,7 +383,7 @@ func (hp *hostsParser) addRecord(ip net.IP, hosts []string) { } // addRules adds rules and rule translations for the line. -func (hp *hostsParser) addRules(ip net.IP, host, line string) { +func (hp *hostsParser) addRules(ip netip.Addr, host, line string) { rule, rulePtr := hp.writeRules(host, ip) hp.translations[rule], hp.translations[rulePtr] = line, line @@ -402,8 +392,9 @@ func (hp *hostsParser) addRules(ip net.IP, host, line string) { // writeRules writes the actual rule for the qtype and the PTR for the host-ip // pair into internal builders. -func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { - arpa, err := netutil.IPToReversedAddr(ip) +func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr string) { + // TODO(a.garipov): Add a netip.Addr version to netutil. + arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) if err != nil { return "", "" } @@ -421,7 +412,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) var qtype string // The validation of the IP address has been performed earlier so it is // guaranteed to be either an IPv4 or an IPv6. - if ip.To4() != nil { + if ip.Is4() { qtype = "A" } else { qtype = "AAAA" @@ -448,51 +439,8 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) return rule, rulePtr } -// equalSet returns true if the internal hosts table just parsed equals target. -// target's values must be of type *HostsRecord. -func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) { - if target == nil { - // hp.table shouldn't appear nil since it's initialized on each refresh. - return target == hp.table - } - - if hp.table.Len() != target.Len() { - return false - } - - hp.table.Range(func(ip net.IP, recVal any) (cont bool) { - var targetVal any - targetVal, ok = target.Get(ip) - if !ok { - return false - } - - var rec *HostsRecord - rec, ok = recVal.(*HostsRecord) - if !ok { - log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal) - - return false - } - - var targetRec *HostsRecord - targetRec, ok = targetVal.(*HostsRecord) - if !ok { - log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal) - - return false - } - - ok = rec.Equal(targetRec) - - return ok - }) - - return ok -} - // sendUpd tries to send the parsed data to the ch. -func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) { +func (hp *hostsParser) sendUpd(ch chan HostsRecords) { log.Debug("%s: sending upd", hostsContainerPref) upd := hp.table @@ -530,14 +478,14 @@ func (hc *HostsContainer) refresh() (err error) { return fmt.Errorf("refreshing : %w", err) } - if hp.equalSet(hc.last) { + if maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) { log.Debug("%s: no changes detected", hostsContainerPref) return nil } defer hp.sendUpd(hc.updates) - hc.last = hp.table.ShallowClone() + hc.last = maps.Clone(hp.table) var rulesStrg *filterlist.RuleStorage if rulesStrg, err = hp.newStrg(hc.listID); err != nil { diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index d2637d85..a5822530 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -3,6 +3,7 @@ package aghnet import ( "io/fs" "net" + "net/netip" "path" "strings" "sync/atomic" @@ -13,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" @@ -135,7 +137,7 @@ func TestNewHostsContainer(t *testing.T) { func TestHostsContainer_refresh(t *testing.T) { // TODO(e.burkov): Test the case with no actual updates. - ip := net.IP{127, 0, 0, 1} + ip := netutil.IPv4Localhost() ipStr := ip.String() testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}} @@ -167,17 +169,13 @@ func TestHostsContainer_refresh(t *testing.T) { require.True(t, ok) require.NotNil(t, upd) - assert.Equal(t, 1, upd.Len()) + assert.Len(t, upd, 1) - v, ok := upd.Get(ip) + rec, ok := upd[ip] require.True(t, ok) - - require.IsType(t, (*HostsRecord)(nil), v) - - rec, _ := v.(*HostsRecord) require.NotNil(t, rec) - assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want) + assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want) } t.Run("initial_refresh", func(t *testing.T) { @@ -562,13 +560,13 @@ func TestHostsContainer(t *testing.T) { } func TestUniqueRules_ParseLine(t *testing.T) { - ip := net.IP{127, 0, 0, 1} + ip := netutil.IPv4Localhost() ipStr := ip.String() testCases := []struct { name string line string - wantIP net.IP + wantIP netip.Addr wantHosts []string }{{ name: "simple", @@ -583,7 +581,7 @@ func TestUniqueRules_ParseLine(t *testing.T) { }, { name: "invalid_line", line: ipStr, - wantIP: nil, + wantIP: netip.Addr{}, wantHosts: nil, }, { name: "invalid_line_hostname", @@ -598,7 +596,7 @@ func TestUniqueRules_ParseLine(t *testing.T) { }, { name: "whole_comment", line: `# ` + ipStr + ` hostname`, - wantIP: nil, + wantIP: netip.Addr{}, wantHosts: nil, }, { name: "partial_comment", @@ -608,7 +606,7 @@ func TestUniqueRules_ParseLine(t *testing.T) { }, { name: "empty", line: ``, - wantIP: nil, + wantIP: netip.Addr{}, wantHosts: nil, }} @@ -616,7 +614,7 @@ func TestUniqueRules_ParseLine(t *testing.T) { hp := hostsParser{} t.Run(tc.name, func(t *testing.T) { got, hosts := hp.parseLine(tc.line) - assert.True(t, tc.wantIP.Equal(got)) + assert.Equal(t, tc.wantIP, got) assert.Equal(t, tc.wantHosts, hosts) }) } diff --git a/internal/home/clients.go b/internal/home/clients.go index bb1a3210..fa230178 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -332,8 +332,8 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { } } -// Exists checks if client with this IP address already exists. -func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) { +// exists checks if client with this IP address already exists. +func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -414,7 +414,7 @@ func (clients *clientsContainer) clientOrArtificial( } var rc *RuntimeClient - rc, ok = clients.FindRuntimeClient(ip) + rc, ok = clients.findRuntimeClient(ip) if ok { return &querylog.Client{ Name: rc.Host, @@ -551,8 +551,8 @@ func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *Runtime return rc, ok } -// FindRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) { +// findRuntimeClient finds a runtime client by their IP. +func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) { if ip == nil { return nil, false } @@ -749,8 +749,8 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { return nil } -// SetWHOISInfo sets the WHOIS information for a client. -func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) { +// setWHOISInfo sets the WHOIS information for a client. +func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) { clients.lock.Lock() defer clients.lock.Unlock() @@ -795,12 +795,23 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc clients.lock.Lock() defer clients.lock.Unlock() - return clients.addHostLocked(ip, host, src), nil + // TODO(a.garipov): Remove once we switch to netip.Addr more fully. + ipAddr, err := netutil.IPToAddrNoMapped(ip) + if err != nil { + return false, fmt.Errorf("adding host: %w", err) + } + + return clients.addHostLocked(ipAddr, host, src), nil } -// addHostLocked adds a new IP-hostname pairing. For internal use only. -func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) { - rc, ok := clients.findRuntimeClientLocked(ip) +// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be +// locked. +func (clients *clientsContainer) addHostLocked( + ip netip.Addr, + host string, + src clientSource, +) (ok bool) { + rc, ok := clients.ipToRC[ip] if ok { if rc.Source > src { return false @@ -815,15 +826,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien WHOISInfo: &RuntimeClientWHOISInfo{}, } - // TODO(a.garipov): Remove once we switch to netip.Addr more fully. - ipAddr, err := netutil.IPToAddrNoMapped(ip) - if err != nil { - log.Error("clients: bad client ip %v: %s", ip, err) - - return false - } - - clients.ipToRC[ipAddr] = rc + clients.ipToRC[ip] = rc } log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC)) @@ -846,28 +849,17 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) { // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. -// -//lint:ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap]. -func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { +func (clients *clientsContainer) addFromHostsFile(hosts aghnet.HostsRecords) { clients.lock.Lock() defer clients.lock.Unlock() clients.rmHostsBySrc(ClientSourceHostsFile) n := 0 - hosts.Range(func(ip net.IP, v any) (cont bool) { - rec, ok := v.(*aghnet.HostsRecord) - if !ok { - log.Error("clients: bad type %T in hosts for %s", v, ip) - - return true - } - + for ip, rec := range hosts { clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile) n++ - - return true - }) + } log.Debug("clients: added %d client aliases from system hosts file", n) } @@ -928,7 +920,15 @@ func (clients *clientsContainer) updateFromDHCP(add bool) { continue } - ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP) + // TODO(a.garipov): Remove once we switch to netip.Addr more fully. + ipAddr, err := netutil.IPToAddrNoMapped(l.IP) + if err != nil { + log.Error("clients: bad client ip %v from dhcp: %s", l.IP, err) + + continue + } + + ok := clients.addHostLocked(ipAddr, l.Hostname, ClientSourceDHCP) if ok { n++ } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 407b2513..636971eb 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -57,9 +57,9 @@ func TestClients(t *testing.T) { assert.Equal(t, "client2", c.Name) - assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile)) + assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) + assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { @@ -109,8 +109,8 @@ func TestClients(t *testing.T) { }) require.NoError(t, err) - assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) + assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) err = clients.Update("client1", &Client{ IDs: []string{"1.1.1.2"}, @@ -139,7 +139,7 @@ func TestClients(t *testing.T) { ok := clients.Del("client1-renamed") require.True(t, ok) - assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { @@ -165,7 +165,7 @@ func TestClients(t *testing.T) { assert.True(t, ok) - assert.True(t, clients.Exists(ip, ClientSourceHostsFile)) + assert.True(t, clients.exists(ip, ClientSourceHostsFile)) }) t.Run("dhcp_replaces_arp", func(t *testing.T) { @@ -175,13 +175,13 @@ func TestClients(t *testing.T) { require.NoError(t, err) assert.True(t, ok) - assert.True(t, clients.Exists(ip, ClientSourceARP)) + assert.True(t, clients.exists(ip, ClientSourceARP)) ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP) require.NoError(t, err) assert.True(t, ok) - assert.True(t, clients.Exists(ip, ClientSourceDHCP)) + assert.True(t, clients.exists(ip, ClientSourceDHCP)) }) t.Run("addhost_fail", func(t *testing.T) { @@ -203,7 +203,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") - clients.SetWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip.AsSlice(), whois) rc := clients.ipToRC[ip] require.NotNil(t, rc) @@ -217,7 +217,7 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) - clients.SetWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip.AsSlice(), whois) rc := clients.ipToRC[ip] require.NotNil(t, rc) @@ -234,7 +234,7 @@ func TestClientsWHOIS(t *testing.T) { require.NoError(t, err) assert.True(t, ok) - clients.SetWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip.AsSlice(), whois) rc := clients.ipToRC[ip] require.Nil(t, rc) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 30b0608e..313fd998 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -241,7 +241,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) { - rc, ok := clients.FindRuntimeClient(ip) + rc, ok := clients.findRuntimeClient(ip) if !ok { // It is still possible that the IP used to be in the runtime clients // list, but then the server was reloaded. So, check the DNS server's diff --git a/internal/home/rdns.go b/internal/home/rdns.go index 924aff37..e44000b3 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -101,7 +101,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) { func (r *RDNS) Begin(ip net.IP) { r.ensurePrivateCache() - if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) { + if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) { return } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 3f6ce4c7..e8d28e61 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -259,7 +259,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { return } - assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS)) + assert.True(t, cc.exists(tc.cliIP, ClientSourceRDNS)) }) } } diff --git a/internal/home/whois.go b/internal/home/whois.go index d7543faa..c9834708 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -252,6 +252,6 @@ func (w *WHOIS) workerLoop() { continue } - w.clients.SetWHOISInfo(ip, info) + w.clients.setWHOISInfo(ip, info) } }