mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-04-02 15:33:36 +03:00
Pull request 2265: AG-27492-client-runtime-storage
Squashed commit of the following: commit a164bace2e0333cf95622f34df7b0e79eac69f41 Merge:6567cd330
184f476bd
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Aug 21 16:14:55 2024 +0300 Merge branch 'master' into AG-27492-client-runtime-storage commit6567cd330c
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Aug 20 16:45:43 2024 +0300 all: imp code commit243123a404
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 15 19:15:54 2024 +0300 all: add tests commit6489996878
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Aug 5 15:12:05 2024 +0300 all: client runtime storage
This commit is contained in:
parent
184f476bdc
commit
30c0bbe5cc
7 changed files with 420 additions and 88 deletions
|
@ -8,6 +8,7 @@ import (
|
||||||
"encoding"
|
"encoding"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
)
|
)
|
||||||
|
@ -120,6 +121,7 @@ func (r *Runtime) Info() (cs Source, host string) {
|
||||||
|
|
||||||
// SetInfo sets a host as a client information from the cs.
|
// SetInfo sets a host as a client information from the cs.
|
||||||
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||||
|
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
||||||
if len(hosts) == 1 && hosts[0] == "" {
|
if len(hosts) == 1 && hosts[0] == "" {
|
||||||
hosts = []string{}
|
hosts = []string{}
|
||||||
}
|
}
|
||||||
|
@ -175,3 +177,15 @@ func (r *Runtime) isEmpty() (ok bool) {
|
||||||
func (r *Runtime) Addr() (ip netip.Addr) {
|
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||||
return r.ip
|
return r.ip
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone returns a deep copy of the runtime client.
|
||||||
|
func (r *Runtime) Clone() (c *Runtime) {
|
||||||
|
return &Runtime{
|
||||||
|
ip: r.ip,
|
||||||
|
whois: r.whois.Clone(),
|
||||||
|
arp: slices.Clone(r.arp),
|
||||||
|
rdns: slices.Clone(r.rdns),
|
||||||
|
dhcp: slices.Clone(r.dhcp),
|
||||||
|
hostsFile: slices.Clone(r.hostsFile),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/container"
|
"github.com/AdguardTeam/golibs/container"
|
||||||
|
@ -31,8 +32,6 @@ type Storage struct {
|
||||||
index *index
|
index *index
|
||||||
|
|
||||||
// runtimeIndex contains information about runtime clients.
|
// runtimeIndex contains information about runtime clients.
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
runtimeIndex *RuntimeIndex
|
runtimeIndex *RuntimeIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,20 +235,75 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||||
return s.runtimeIndex.Client(ip)
|
return s.runtimeIndex.Client(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRuntime saves the runtime client information in the storage. IP address
|
// UpdateRuntime updates the stored runtime client with information from rc. If
|
||||||
// of a client must be unique. rc must not be nil.
|
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
||||||
//
|
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
func (s *Storage) AddRuntime(rc *Runtime) {
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
s.runtimeIndex.Add(rc)
|
return s.updateRuntimeLocked(rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRuntimeLocked updates the stored runtime client with information from
|
||||||
|
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
||||||
|
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
|
||||||
|
stored := s.runtimeIndex.Client(rc.ip)
|
||||||
|
if stored == nil {
|
||||||
|
s.runtimeIndex.Add(rc.Clone())
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.whois != nil {
|
||||||
|
stored.whois = rc.whois.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.arp != nil {
|
||||||
|
stored.arp = slices.Clone(rc.arp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.rdns != nil {
|
||||||
|
stored.rdns = slices.Clone(rc.rdns)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.dhcp != nil {
|
||||||
|
stored.dhcp = slices.Clone(rc.dhcp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.hostsFile != nil {
|
||||||
|
stored.hostsFile = slices.Clone(rc.hostsFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateBySource updates the stored runtime clients information from the
|
||||||
|
// specified source and returns the number of added and removed clients.
|
||||||
|
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, rc := range s.runtimeIndex.index {
|
||||||
|
rc.unset(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rc := range rcs {
|
||||||
|
if s.updateRuntimeLocked(rc) {
|
||||||
|
added++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for ip, rc := range s.runtimeIndex.index {
|
||||||
|
if rc.isEmpty() {
|
||||||
|
delete(s.runtimeIndex.index, ip)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return added, removed
|
||||||
}
|
}
|
||||||
|
|
||||||
// SizeRuntime returns the number of the runtime clients.
|
// SizeRuntime returns the number of the runtime clients.
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
func (s *Storage) SizeRuntime() (n int) {
|
func (s *Storage) SizeRuntime() (n int) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
@ -258,8 +312,6 @@ func (s *Storage) SizeRuntime() (n int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
@ -267,16 +319,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||||
s.runtimeIndex.Range(f)
|
s.runtimeIndex.Range(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRuntime removes the runtime client by ip.
|
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.runtimeIndex.Delete(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteBySource removes all runtime clients that have information only from
|
// DeleteBySource removes all runtime clients that have information only from
|
||||||
// the specified source and returns the number of removed clients.
|
// the specified source and returns the number of removed clients.
|
||||||
//
|
//
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -25,9 +26,19 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||||
require.NoError(tb, s.Add(c))
|
require.NoError(tb, s.Add(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
require.Equal(tb, len(m), s.Size())
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newRuntimeClient is a helper function that returns a new runtime client.
|
||||||
|
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
|
||||||
|
rc = client.NewRuntime(ip)
|
||||||
|
rc.SetInfo(source, []string{host})
|
||||||
|
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
|
||||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||||
// error.
|
// error.
|
||||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||||
|
@ -43,6 +54,9 @@ func TestStorage_Add(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
existingName = "existing_name"
|
existingName = "existing_name"
|
||||||
existingClientID = "existing_client_id"
|
existingClientID = "existing_client_id"
|
||||||
|
|
||||||
|
allowedTag = "tag"
|
||||||
|
notAllowedTag = "not_allowed_tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -60,7 +74,7 @@ func TestStorage_Add(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s := client.NewStorage(&client.Config{
|
s := client.NewStorage(&client.Config{
|
||||||
AllowedTags: nil,
|
AllowedTags: []string{allowedTag},
|
||||||
})
|
})
|
||||||
err := s.Add(existingClient)
|
err := s.Add(existingClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -119,6 +133,15 @@ func TestStorage_Add(t *testing.T) {
|
||||||
},
|
},
|
||||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
`uses the same ClientID "existing_client_id"`,
|
`uses the same ClientID "existing_client_id"`,
|
||||||
|
}, {
|
||||||
|
name: "not_allowed_tag",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "nont_allowed_tag",
|
||||||
|
Tags: []string{notAllowedTag},
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -341,6 +364,127 @@ func TestStorage_FindLoose(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStorage_FindByName(t *testing.T) {
|
||||||
|
const (
|
||||||
|
cliIP1 = "1.1.1.1"
|
||||||
|
cliIP2 = "2.2.2.2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
clientExistingName = "client_existing"
|
||||||
|
clientAnotherExistingName = "client_another_existing"
|
||||||
|
nonExistingClientName = "client_non_existing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientExisting = &client.Persistent{
|
||||||
|
Name: clientExistingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAnotherExisting = &client.Persistent{
|
||||||
|
Name: clientAnotherExistingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*client.Persistent{
|
||||||
|
clientExisting,
|
||||||
|
clientAnotherExisting,
|
||||||
|
}
|
||||||
|
s := newStorage(t, clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *client.Persistent
|
||||||
|
name string
|
||||||
|
clientName string
|
||||||
|
}{{
|
||||||
|
name: "existing",
|
||||||
|
clientName: clientExistingName,
|
||||||
|
want: clientExisting,
|
||||||
|
}, {
|
||||||
|
name: "another_existing",
|
||||||
|
clientName: clientAnotherExistingName,
|
||||||
|
want: clientAnotherExisting,
|
||||||
|
}, {
|
||||||
|
name: "non_existing",
|
||||||
|
clientName: nonExistingClientName,
|
||||||
|
want: nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, ok := s.FindByName(tc.clientName)
|
||||||
|
if tc.want == nil {
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_FindByMAC(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
||||||
|
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
||||||
|
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientExisting = &client.Persistent{
|
||||||
|
Name: "client",
|
||||||
|
MACs: []net.HardwareAddr{cliMAC},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAnotherExisting = &client.Persistent{
|
||||||
|
Name: "another_client",
|
||||||
|
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*client.Persistent{
|
||||||
|
clientExisting,
|
||||||
|
clientAnotherExisting,
|
||||||
|
}
|
||||||
|
s := newStorage(t, clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *client.Persistent
|
||||||
|
name string
|
||||||
|
clientMAC net.HardwareAddr
|
||||||
|
}{{
|
||||||
|
name: "existing",
|
||||||
|
clientMAC: cliMAC,
|
||||||
|
want: clientExisting,
|
||||||
|
}, {
|
||||||
|
name: "another_existing",
|
||||||
|
clientMAC: cliAnotherMAC,
|
||||||
|
want: clientAnotherExisting,
|
||||||
|
}, {
|
||||||
|
name: "non_existing",
|
||||||
|
clientMAC: nonExistingClientMAC,
|
||||||
|
want: nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, ok := s.FindByMAC(tc.clientMAC)
|
||||||
|
if tc.want == nil {
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStorage_Update(t *testing.T) {
|
func TestStorage_Update(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
clientName = "client_name"
|
clientName = "client_name"
|
||||||
|
@ -479,3 +623,157 @@ func TestStorage_RangeByName(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStorage_UpdateRuntime(t *testing.T) {
|
||||||
|
const (
|
||||||
|
addedARP = "added_arp"
|
||||||
|
addedSecondARP = "added_arp"
|
||||||
|
|
||||||
|
updatedARP = "updated_arp"
|
||||||
|
|
||||||
|
cliCity = "City"
|
||||||
|
cliCountry = "Country"
|
||||||
|
cliOrgname = "Orgname"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ip = netip.MustParseAddr("1.1.1.1")
|
||||||
|
ip2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
)
|
||||||
|
|
||||||
|
updated := client.NewRuntime(ip)
|
||||||
|
updated.SetInfo(client.SourceARP, []string{updatedARP})
|
||||||
|
|
||||||
|
info := &whois.Info{
|
||||||
|
City: cliCity,
|
||||||
|
Country: cliCountry,
|
||||||
|
Orgname: cliOrgname,
|
||||||
|
}
|
||||||
|
updated.SetWHOIS(info)
|
||||||
|
|
||||||
|
s := client.NewStorage(&client.Config{
|
||||||
|
AllowedTags: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add_arp_client", func(t *testing.T) {
|
||||||
|
added := client.NewRuntime(ip)
|
||||||
|
added.SetInfo(client.SourceARP, []string{addedARP})
|
||||||
|
|
||||||
|
require.True(t, s.UpdateRuntime(added))
|
||||||
|
require.Equal(t, 1, s.SizeRuntime())
|
||||||
|
|
||||||
|
got := s.ClientRuntime(ip)
|
||||||
|
source, host := got.Info()
|
||||||
|
assert.Equal(t, client.SourceARP, source)
|
||||||
|
assert.Equal(t, addedARP, host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add_second_arp_client", func(t *testing.T) {
|
||||||
|
added := client.NewRuntime(ip2)
|
||||||
|
added.SetInfo(client.SourceARP, []string{addedSecondARP})
|
||||||
|
|
||||||
|
require.True(t, s.UpdateRuntime(added))
|
||||||
|
require.Equal(t, 2, s.SizeRuntime())
|
||||||
|
|
||||||
|
got := s.ClientRuntime(ip2)
|
||||||
|
source, host := got.Info()
|
||||||
|
assert.Equal(t, client.SourceARP, source)
|
||||||
|
assert.Equal(t, addedSecondARP, host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update_first_client", func(t *testing.T) {
|
||||||
|
require.False(t, s.UpdateRuntime(updated))
|
||||||
|
got := s.ClientRuntime(ip)
|
||||||
|
require.Equal(t, 2, s.SizeRuntime())
|
||||||
|
|
||||||
|
source, host := got.Info()
|
||||||
|
assert.Equal(t, client.SourceARP, source)
|
||||||
|
assert.Equal(t, updatedARP, host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove_arp_info", func(t *testing.T) {
|
||||||
|
n := s.DeleteBySource(client.SourceARP)
|
||||||
|
require.Equal(t, 1, n)
|
||||||
|
require.Equal(t, 1, s.SizeRuntime())
|
||||||
|
|
||||||
|
got := s.ClientRuntime(ip)
|
||||||
|
source, _ := got.Info()
|
||||||
|
assert.Equal(t, client.SourceWHOIS, source)
|
||||||
|
assert.Equal(t, info, got.WHOIS())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove_whois_info", func(t *testing.T) {
|
||||||
|
n := s.DeleteBySource(client.SourceWHOIS)
|
||||||
|
require.Equal(t, 1, n)
|
||||||
|
require.Equal(t, 0, s.SizeRuntime())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_BatchUpdateBySource(t *testing.T) {
|
||||||
|
const (
|
||||||
|
defSrc = client.SourceARP
|
||||||
|
|
||||||
|
cliFirstHost1 = "host1"
|
||||||
|
cliFirstHost2 = "host2"
|
||||||
|
cliUpdatedHost3 = "host3"
|
||||||
|
cliUpdatedHost4 = "host4"
|
||||||
|
cliUpdatedHost5 = "host5"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
|
||||||
|
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
|
||||||
|
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
|
||||||
|
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
|
||||||
|
)
|
||||||
|
|
||||||
|
firstClients := []*client.Runtime{
|
||||||
|
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
|
||||||
|
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedClients := []*client.Runtime{
|
||||||
|
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
|
||||||
|
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
|
||||||
|
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
|
||||||
|
}
|
||||||
|
|
||||||
|
s := client.NewStorage(&client.Config{
|
||||||
|
AllowedTags: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
|
||||||
|
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
|
||||||
|
require.Equal(t, len(firstClients), added)
|
||||||
|
require.Equal(t, 0, removed)
|
||||||
|
require.Equal(t, len(firstClients), s.SizeRuntime())
|
||||||
|
|
||||||
|
rc := s.ClientRuntime(cliFirstIP1)
|
||||||
|
src, host := rc.Info()
|
||||||
|
assert.Equal(t, defSrc, src)
|
||||||
|
assert.Equal(t, cliFirstHost1, host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update_storage", func(t *testing.T) {
|
||||||
|
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
|
||||||
|
require.Equal(t, len(updatedClients), added)
|
||||||
|
require.Equal(t, len(firstClients), removed)
|
||||||
|
require.Equal(t, len(updatedClients), s.SizeRuntime())
|
||||||
|
|
||||||
|
rc := s.ClientRuntime(cliUpdatedIP3)
|
||||||
|
src, host := rc.Info()
|
||||||
|
assert.Equal(t, defSrc, src)
|
||||||
|
assert.Equal(t, cliUpdatedHost3, host)
|
||||||
|
|
||||||
|
rc = s.ClientRuntime(cliFirstIP1)
|
||||||
|
assert.Nil(t, rc)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove_all", func(t *testing.T) {
|
||||||
|
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
|
||||||
|
require.Equal(t, 0, added)
|
||||||
|
require.Equal(t, len(updatedClients), removed)
|
||||||
|
require.Equal(t, 0, s.SizeRuntime())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -47,9 +47,6 @@ type clientsContainer struct {
|
||||||
// storage stores information about persistent clients.
|
// storage stores information about persistent clients.
|
||||||
storage *client.Storage
|
storage *client.Storage
|
||||||
|
|
||||||
// runtimeIndex stores information about runtime clients.
|
|
||||||
runtimeIndex *client.RuntimeIndex
|
|
||||||
|
|
||||||
// dhcp is the DHCP service implementation.
|
// dhcp is the DHCP service implementation.
|
||||||
dhcp DHCP
|
dhcp DHCP
|
||||||
|
|
||||||
|
@ -105,8 +102,6 @@ func (clients *clientsContainer) Init(
|
||||||
return errors.Error("clients container already initialized")
|
return errors.Error("clients container already initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
|
||||||
|
|
||||||
clients.storage = client.NewStorage(&client.Config{
|
clients.storage = client.NewStorage(&client.Config{
|
||||||
AllowedTags: clientTags,
|
AllowedTags: clientTags,
|
||||||
})
|
})
|
||||||
|
@ -358,7 +353,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||||
return client.SourcePersistent
|
return client.SourcePersistent
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
src, _ = rc.Info()
|
src, _ = rc.Info()
|
||||||
}
|
}
|
||||||
|
@ -539,22 +534,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
||||||
return clients.storage.FindByMAC(foundMAC)
|
return clients.storage.FindByMAC(foundMAC)
|
||||||
}
|
}
|
||||||
|
|
||||||
// runtimeClient returns a runtime client from internal index. Note that it
|
|
||||||
// doesn't include DHCP clients.
|
|
||||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
|
||||||
if ip == (netip.Addr{}) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
return clients.runtimeIndex.Client(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRuntimeClient finds a runtime client by their IP.
|
// findRuntimeClient finds a runtime client by their IP.
|
||||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||||
rc = clients.runtimeClient(ip)
|
rc = clients.storage.ClientRuntime(ip)
|
||||||
host := clients.dhcp.HostByIP(ip)
|
host := clients.dhcp.HostByIP(ip)
|
||||||
|
|
||||||
if host != "" {
|
if host != "" {
|
||||||
|
@ -580,20 +562,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := client.NewRuntime(ip)
|
||||||
if rc == nil {
|
|
||||||
// Create a RuntimeClient implicitly so that we don't do this check
|
|
||||||
// again.
|
|
||||||
rc = client.NewRuntime(ip)
|
|
||||||
clients.runtimeIndex.Add(rc)
|
|
||||||
|
|
||||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
|
||||||
} else {
|
|
||||||
host, _ := rc.Info()
|
|
||||||
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc.SetWHOIS(wi)
|
rc.SetWHOIS(wi)
|
||||||
|
clients.storage.UpdateRuntime(rc)
|
||||||
|
|
||||||
|
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||||
|
@ -644,26 +617,20 @@ func (clients *clientsContainer) addHostLocked(
|
||||||
host string,
|
host string,
|
||||||
src client.Source,
|
src client.Source,
|
||||||
) (ok bool) {
|
) (ok bool) {
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := client.NewRuntime(ip)
|
||||||
if rc == nil {
|
rc.SetInfo(src, []string{host})
|
||||||
if src < client.SourceDHCP {
|
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
|
||||||
if clients.dhcp.HostByIP(ip) != "" {
|
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rc = client.NewRuntime(ip)
|
|
||||||
clients.runtimeIndex.Add(rc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rc.SetInfo(src, []string{host})
|
clients.storage.UpdateRuntime(rc)
|
||||||
|
|
||||||
log.Debug(
|
log.Debug(
|
||||||
"clients: adding client info %s -> %q %q [%d]",
|
"clients: adding client info %s -> %q %q [%d]",
|
||||||
ip,
|
ip,
|
||||||
src,
|
src,
|
||||||
host,
|
host,
|
||||||
clients.runtimeIndex.Size(),
|
clients.storage.SizeRuntime(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
@ -675,23 +642,22 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
var rcs []*client.Runtime
|
||||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
|
||||||
|
|
||||||
added := 0
|
|
||||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||||
// Only the first name of the first record is considered a canonical
|
// Only the first name of the first record is considered a canonical
|
||||||
// hostname for the IP address.
|
// hostname for the IP address.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Consider using all the names from all the records.
|
// TODO(e.burkov): Consider using all the names from all the records.
|
||||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
rc := client.NewRuntime(addr)
|
||||||
added++
|
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
|
||||||
}
|
|
||||||
|
rcs = append(rcs, rc)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
|
||||||
|
log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||||
|
@ -715,17 +681,16 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
var rcs []*client.Runtime
|
||||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
|
||||||
|
|
||||||
added := 0
|
|
||||||
for _, n := range ns {
|
for _, n := range ns {
|
||||||
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
|
rc := client.NewRuntime(n.IP)
|
||||||
added++
|
rc.SetInfo(client.SourceARP, []string{n.Name})
|
||||||
}
|
|
||||||
|
rcs = append(rcs, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
|
||||||
|
log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// close gracefully closes all the client-specific upstream configurations of
|
// close gracefully closes all the client-specific upstream configurations of
|
||||||
|
|
|
@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
t.Run("new_client", func(t *testing.T) {
|
t.Run("new_client", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.255")
|
ip := netip.MustParseAddr("1.1.1.255")
|
||||||
clients.setWHOISInfo(ip, whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
require.NotNil(t, rc)
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
assert.Equal(t, whois, rc.WHOIS())
|
assert.Equal(t, whois, rc.WHOIS())
|
||||||
|
@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
require.NotNil(t, rc)
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
assert.Equal(t, whois, rc.WHOIS())
|
assert.Equal(t, whois, rc.WHOIS())
|
||||||
|
@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
clients.setWHOISInfo(ip, whois)
|
||||||
rc := clients.runtimeIndex.Client(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
require.Nil(t, rc)
|
require.Nil(t, rc)
|
||||||
|
|
||||||
assert.True(t, clients.storage.RemoveByName("client1"))
|
assert.True(t, clients.storage.RemoveByName("client1"))
|
||||||
|
|
|
@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||||
src, host := rc.Info()
|
src, host := rc.Info()
|
||||||
cj := runtimeClientJSON{
|
cj := runtimeClientJSON{
|
||||||
WHOIS: whoisOrEmpty(rc),
|
WHOIS: whoisOrEmpty(rc),
|
||||||
|
|
|
@ -354,6 +354,19 @@ type Info struct {
|
||||||
Orgname string `json:"orgname,omitempty"`
|
Orgname string `json:"orgname,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone returns a deep copy of the WHOIS info.
|
||||||
|
func (i *Info) Clone() (c *Info) {
|
||||||
|
if i == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Info{
|
||||||
|
City: i.City,
|
||||||
|
Country: i.Country,
|
||||||
|
Orgname: i.Orgname,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cacheItem represents an item that we will store in the cache.
|
// cacheItem represents an item that we will store in the cache.
|
||||||
type cacheItem struct {
|
type cacheItem struct {
|
||||||
// expiry is the time when cacheItem will expire.
|
// expiry is the time when cacheItem will expire.
|
||||||
|
|
Loading…
Add table
Reference in a new issue