mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-21 20:45:33 +03:00
Pull request 2208: AG-27492-client-persistent-list
Squashed commit of the following: commit 1b1a21b07baa15499e5e4963d35bfd2e542533ed Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed May 8 17:32:38 2024 +0300 client: imp tests commit 7e6d17158a254aa29bf4033fb68171d4209bb954 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed May 8 17:27:00 2024 +0300 client: imp tests commit 5e4cd2b3ca9557929b9b79a0610151ce09c792f9 Merge: 7faddd8aa1a62ce471
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed May 8 15:57:33 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-list commit 7faddd8aade2b1b791beec694b88513b0a2a520e Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 6 20:55:43 2024 +0300 client: imp code commit 54212e975b700f792a53fc3bfe1c2970778e05ea Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 6 20:24:18 2024 +0300 all: imp code commit 3f23c9af470036c2166e20c8d0b5d84810b35b6e Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 6 17:07:40 2024 +0300 home: imp tests commit 39b99fc050047cebadc51ae64e220ec1cb873d83 Merge: 76469ac5917c4eeb64
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 6 16:39:56 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-list commit 76469ac59400aae2f7563750a981138b8cbf3aa1 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 6 14:36:22 2024 +0300 home: imp naming commit 4e4aa5802c9aafc67c52b8a290d8046531f8a1c8 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu May 2 19:50:45 2024 +0300 client: imp docs commit bf5c23a72c93e58c8bc7e0ca896b2ea28519cf54 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu May 2 19:40:53 2024 +0300 home: add tests commit c6cdba7a8d0dfce22634f88258f61abb09ecca5a Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 24 14:21:44 2024 +0300 all: add tests commit 1fc43cb45efbd428abaae9eba030f9bea818dfe3 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 19 19:19:48 2024 +0300 all: add tests commit ccc423b296d9037f0aa23a125a5ad3af95b8c9f3 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 19 15:37:15 2024 +0300 all: client persistent list
This commit is contained in:
parent
1a62ce471e
commit
71c44fa40c
7 changed files with 634 additions and 151 deletions
|
@ -4,9 +4,12 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||
|
@ -29,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
|
|||
|
||||
// Index stores all information about persistent clients.
|
||||
type Index struct {
|
||||
// nameToUID maps client name to UID.
|
||||
nameToUID map[string]UID
|
||||
|
||||
// clientIDToUID maps client ID to UID.
|
||||
clientIDToUID map[string]UID
|
||||
|
||||
|
@ -48,6 +54,7 @@ type Index struct {
|
|||
// NewIndex initializes the new instance of client index.
|
||||
func NewIndex() (ci *Index) {
|
||||
return &Index{
|
||||
nameToUID: map[string]UID{},
|
||||
clientIDToUID: map[string]UID{},
|
||||
ipToUID: map[netip.Addr]UID{},
|
||||
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||
|
@ -63,6 +70,8 @@ func (ci *Index) Add(c *Persistent) {
|
|||
panic("client must contain uid")
|
||||
}
|
||||
|
||||
ci.nameToUID[c.Name] = c.UID
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
ci.clientIDToUID[id] = c.UID
|
||||
}
|
||||
|
@ -83,21 +92,26 @@ func (ci *Index) Add(c *Persistent) {
|
|||
ci.uidToClient[c.UID] = c
|
||||
}
|
||||
|
||||
// ErrDuplicateUID is an error returned by [Index.Clashes] when adding a
|
||||
// persistent client with a UID that already exists in an index.
|
||||
const ErrDuplicateUID errors.Error = "duplicate uid"
|
||||
// ClashesUID returns existing persistent client with the same UID as c. Note
|
||||
// that this is only possible when configuration contains duplicate fields.
|
||||
func (ci *Index) ClashesUID(c *Persistent) (err error) {
|
||||
p, ok := ci.uidToClient[c.UID]
|
||||
if ok {
|
||||
return fmt.Errorf("another client %q uses the same uid", p.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clashes returns an error if the index contains a different persistent client
|
||||
// with at least a single identifier contained by c. c must be non-nil.
|
||||
func (ci *Index) Clashes(c *Persistent) (err error) {
|
||||
_, ok := ci.uidToClient[c.UID]
|
||||
if ok {
|
||||
return ErrDuplicateUID
|
||||
if p := ci.clashesName(c); p != nil {
|
||||
return fmt.Errorf("another client uses the same name %q", p.Name)
|
||||
}
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
var existing UID
|
||||
existing, ok = ci.clientIDToUID[id]
|
||||
existing, ok := ci.clientIDToUID[id]
|
||||
if ok && existing != c.UID {
|
||||
p := ci.uidToClient[existing]
|
||||
|
||||
|
@ -123,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// clashesName returns existing persistent client with the same name as c or
|
||||
// nil. c must be non-nil.
|
||||
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
|
||||
existing, ok := ci.FindByName(c.Name)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if existing.UID != c.UID {
|
||||
return existing
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clashesIP returns a previous client with the same IP address as c. c must be
|
||||
// non-nil.
|
||||
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
|
||||
|
@ -195,13 +224,23 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
|
|||
|
||||
mac, err := net.ParseMAC(id)
|
||||
if err == nil {
|
||||
return ci.findByMAC(mac)
|
||||
return ci.FindByMAC(mac)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find finds persistent client by IP address.
|
||||
// FindByName finds persistent client by name.
|
||||
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
|
||||
uid, found := ci.nameToUID[name]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByIP finds persistent client by IP address.
|
||||
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
||||
uid, found := ci.ipToUID[ip]
|
||||
if found {
|
||||
|
@ -227,6 +266,17 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
// FindByMAC finds persistent client by MAC.
|
||||
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
k := macToKey(mac)
|
||||
uid, found := ci.macToUID[k]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
|
||||
// strips the IPv6 zone index from the stored IP addresses before comparing,
|
||||
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
|
||||
|
@ -247,20 +297,11 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// find finds persistent client by MAC.
|
||||
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
k := macToKey(mac)
|
||||
uid, found := ci.macToUID[k]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Delete removes information about persistent client from the index. c must be
|
||||
// non-nil.
|
||||
func (ci *Index) Delete(c *Persistent) {
|
||||
delete(ci.nameToUID, c.Name)
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
delete(ci.clientIDToUID, id)
|
||||
}
|
||||
|
@ -280,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) {
|
|||
|
||||
delete(ci.uidToClient, c.UID)
|
||||
}
|
||||
|
||||
// Size returns the number of persistent clients.
|
||||
func (ci *Index) Size() (n int) {
|
||||
return len(ci.uidToClient)
|
||||
}
|
||||
|
||||
// Range calls f for each persistent client, unless cont is false. The order is
|
||||
// undefined.
|
||||
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
|
||||
for _, c := range ci.uidToClient {
|
||||
if !f(c) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RangeByName is like [Index.Range] but sorts the persistent clients by name
|
||||
// before iterating ensuring a predictable order.
|
||||
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
|
||||
cs := maps.Values(ci.uidToClient)
|
||||
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
if !f(c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseUpstreams closes upstream configurations of persistent clients.
|
||||
func (ci *Index) CloseUpstreams() (err error) {
|
||||
var errs []error
|
||||
ci.RangeByName(func(c *Persistent) (cont bool) {
|
||||
err = c.CloseUpstreams()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
|
|||
return ci
|
||||
}
|
||||
|
||||
func TestClientIndex(t *testing.T) {
|
||||
func TestClientIndex_Find(t *testing.T) {
|
||||
const (
|
||||
cliIPNone = "1.2.3.4"
|
||||
cliIP1 = "1.1.1.1"
|
||||
|
@ -71,13 +71,14 @@ func TestClientIndex(t *testing.T) {
|
|||
}
|
||||
)
|
||||
|
||||
ci := newIDIndex([]*Persistent{
|
||||
clients := []*Persistent{
|
||||
clientWithBothFams,
|
||||
clientWithSubnet,
|
||||
clientWithMAC,
|
||||
clientWithID,
|
||||
clientLinkLocal,
|
||||
})
|
||||
}
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *Persistent
|
||||
|
@ -296,3 +297,54 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIndex_RangeByName(t *testing.T) {
|
||||
sortedClients := []*Persistent{{
|
||||
Name: "clientA",
|
||||
ClientIDs: []string{"A"},
|
||||
}, {
|
||||
Name: "clientB",
|
||||
ClientIDs: []string{"B"},
|
||||
}, {
|
||||
Name: "clientC",
|
||||
ClientIDs: []string{"C"},
|
||||
}, {
|
||||
Name: "clientD",
|
||||
ClientIDs: []string{"D"},
|
||||
}, {
|
||||
Name: "clientE",
|
||||
ClientIDs: []string{"E"},
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []*Persistent
|
||||
}{{
|
||||
name: "basic",
|
||||
want: sortedClients,
|
||||
}, {
|
||||
name: "nil",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "one_element",
|
||||
want: sortedClients[:1],
|
||||
}, {
|
||||
name: "two_elements",
|
||||
want: sortedClients[:2],
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ci := newIDIndex(tc.want)
|
||||
|
||||
var got []*Persistent
|
||||
ci.RangeByName(func(c *Persistent) (cont bool) {
|
||||
got = append(got, c)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
|
@ -46,10 +45,6 @@ type DHCP interface {
|
|||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
// types (string, netip.Addr, and so on).
|
||||
list map[string]*client.Persistent // name -> client
|
||||
|
||||
// clientIndex stores information about persistent clients.
|
||||
clientIndex *client.Index
|
||||
|
||||
|
@ -61,8 +56,9 @@ type clientsContainer struct {
|
|||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
// clientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
clientChecker BlockedClientChecker
|
||||
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
|
@ -91,6 +87,12 @@ type clientsContainer struct {
|
|||
testing bool
|
||||
}
|
||||
|
||||
// BlockedClientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
type BlockedClientChecker interface {
|
||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
|
@ -101,11 +103,11 @@ func (clients *clientsContainer) Init(
|
|||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.clientIndex != nil {
|
||||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.list = map[string]*client.Persistent{}
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
|
||||
clients.clientIndex = client.NewIndex()
|
||||
|
@ -284,12 +286,14 @@ func (clients *clientsContainer) addFromConfig(
|
|||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
_, err = clients.add(cli)
|
||||
// TODO(s.chzhen): Consider moving to the client index constructor.
|
||||
err = clients.clientIndex.ClashesUID(cli)
|
||||
if err != nil {
|
||||
if errors.Is(err, client.ErrDuplicateUID) {
|
||||
return fmt.Errorf("clients: adding client %s at index %d: %w", cli.Name, i, err)
|
||||
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
|
||||
}
|
||||
|
||||
err = clients.add(cli)
|
||||
if err != nil {
|
||||
// TODO(s.chzhen): Return an error instead of logging if more
|
||||
// stringent requirements are implemented.
|
||||
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
|
||||
|
@ -305,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
objs = make([]*clientObject, 0, len(clients.list))
|
||||
for _, cli := range clients.list {
|
||||
o := &clientObject{
|
||||
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
||||
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
|
||||
objs = append(objs, &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
@ -328,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
|||
IgnoreStatistics: cli.IgnoreStatistics,
|
||||
UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled,
|
||||
UpstreamsCacheSize: cli.UpstreamsCacheSize,
|
||||
}
|
||||
})
|
||||
|
||||
objs = append(objs, o)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
||||
// above loop can generate different orderings when writing to the config
|
||||
|
@ -411,7 +415,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
|||
id string,
|
||||
) (c *querylog.Client, art bool) {
|
||||
defer func() {
|
||||
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
|
||||
c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id)
|
||||
if c.WHOIS == nil {
|
||||
c.WHOIS = &whois.Info{}
|
||||
}
|
||||
|
@ -550,14 +554,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
|||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
|
||||
if found {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return clients.clientIndex.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
|
@ -621,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// add adds a new client object. ok is false if such client already exists or
|
||||
// if an error occurred.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
|
||||
// add adds a persistent client or returns an error.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check ID index
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
clients.addLocked(c)
|
||||
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLocked c to the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) addLocked(c *client.Persistent) {
|
||||
// update Name index
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
clients.clientIndex.Add(c)
|
||||
}
|
||||
|
||||
|
@ -666,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var c *client.Persistent
|
||||
c, ok = clients.list[name]
|
||||
c, ok := clients.clientIndex.FindByName(name)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
@ -684,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) {
|
|||
log.Error("client container: removing client %s: %s", c.Name, err)
|
||||
}
|
||||
|
||||
// Update the name index.
|
||||
delete(clients.list, c.Name)
|
||||
|
||||
// Update the ID index.
|
||||
clients.clientIndex.Delete(c)
|
||||
}
|
||||
|
@ -702,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error)
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// Check the name index.
|
||||
if prev.Name != c.Name {
|
||||
_, ok := clients.list[c.Name]
|
||||
if ok {
|
||||
return errors.Error("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
if c.EqualIDs(prev) {
|
||||
clients.removeLocked(prev)
|
||||
clients.addLocked(c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the ID index.
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
|
@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() {
|
|||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
persistent := maps.Values(clients.list)
|
||||
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, cli := range persistent {
|
||||
if err = cli.CloseUpstreams(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
return clients.clientIndex.CloseUpstreams()
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
|||
}
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
|
||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
||||
}
|
||||
|
@ -72,23 +72,19 @@ func TestClients(t *testing.T) {
|
|||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err := clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
ok, err = clients.add(c)
|
||||
err = clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.find(cli1)
|
||||
c, ok := clients.find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
@ -111,22 +107,20 @@ func TestClients(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
|
@ -145,12 +139,13 @@ func TestClients(t *testing.T) {
|
|||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
prev, ok := clients.list["client1"]
|
||||
prev, ok := clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err := clients.update(prev, &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -160,12 +155,13 @@ func TestClients(t *testing.T) {
|
|||
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.list["client1"]
|
||||
prev, ok = clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err = clients.update(prev, &client.Persistent{
|
||||
Name: "client1-renamed",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
|
@ -177,7 +173,7 @@ func TestClients(t *testing.T) {
|
|||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.list["client1"]
|
||||
nilCli, ok := clients.clientIndex.FindByName("client1")
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
|
@ -265,13 +261,12 @@ func TestClientsWHOIS(t *testing.T) {
|
|||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.2")
|
||||
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
|
@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
|
@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok = clients.addHost(ip, "test", client.SourceRDNS)
|
||||
ok := clients.addHost(ip, "test", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
|
@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
|||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
|
@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) {
|
|||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
||||
assert.Nil(t, upsConf)
|
||||
|
|
|
@ -96,10 +96,12 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
for _, c := range clients.list {
|
||||
clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
|
@ -334,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
|||
return
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err = clients.add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
|
@ -372,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
|||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// updateJSON contains the name and data of the updated persistent client.
|
||||
|
@ -406,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
prev, ok = clients.list[dj.Name]
|
||||
prev, ok = clients.clientIndex.FindByName(dj.Name)
|
||||
}()
|
||||
|
||||
if !ok {
|
||||
|
@ -429,7 +429,9 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
|||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
|
@ -449,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
|||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
|
@ -472,7 +474,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
|||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj = &clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
|
@ -490,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
|||
WHOIS: whoisOrEmpty(rc),
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj
|
||||
|
|
399
internal/home/clientshttp_internal_test.go
Normal file
399
internal/home/clientshttp_internal_test.go
Normal file
|
@ -0,0 +1,399 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testClientIP1 = "1.1.1.1"
|
||||
testClientIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
// testBlockedClientChecker is a mock implementation of the
|
||||
// [BlockedClientChecker] interface.
|
||||
type testBlockedClientChecker struct {
|
||||
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
|
||||
|
||||
// IsBlockedClient implements the [BlockedClientChecker] interface for
|
||||
// *testBlockedClientChecker.
|
||||
func (c *testBlockedClientChecker) IsBlockedClient(
|
||||
ip netip.Addr,
|
||||
clientID string,
|
||||
) (blocked bool, rule string) {
|
||||
return c.onIsBlockedClient(ip, clientID)
|
||||
}
|
||||
|
||||
// newPersistentClient is a helper function that returns a persistent client
|
||||
// with the specified name and newly generated UID.
|
||||
func newPersistentClient(name string) (c *client.Persistent) {
|
||||
return &client.Persistent{
|
||||
Name: name,
|
||||
UID: client.MustNewUID(),
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: &schedule.Weekly{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newPersistentClientWithIDs is a helper function that returns a persistent
|
||||
// client with the specified name and ids.
|
||||
func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
c = newPersistentClient(name)
|
||||
err := c.SetIDs(ids)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// assertClients is a helper function that compares lists of persistent clients.
|
||||
func assertClients(tb testing.TB, want, got []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
require.Len(tb, got, len(want))
|
||||
|
||||
sortFunc := func(a, b *client.Persistent) (n int) {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
}
|
||||
|
||||
slices.SortFunc(want, sortFunc)
|
||||
slices.SortFunc(got, sortFunc)
|
||||
|
||||
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
|
||||
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
|
||||
|
||||
return 0
|
||||
})
|
||||
}
|
||||
|
||||
// assertPersistentClients is a helper function that uses HTTP API to check
|
||||
// whether want persistent clients are the same as the persistent clients stored
|
||||
// in the clients container.
|
||||
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleGetClients(rw, &http.Request{})
|
||||
|
||||
body, err := io.ReadAll(rw.Body)
|
||||
require.NoError(tb, err)
|
||||
|
||||
clientList := &clientListJSON{}
|
||||
err = json.Unmarshal(body, clientList)
|
||||
require.NoError(tb, err)
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cj := range clientList.Clients {
|
||||
var c *client.Persistent
|
||||
c, err = clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
// assertPersistentClientsData is a helper function that checks whether want
|
||||
// persistent clients are the same as the persistent clients stored in data.
|
||||
func assertPersistentClientsData(
|
||||
tb testing.TB,
|
||||
clients *clientsContainer,
|
||||
data []map[string]*clientJSON,
|
||||
want []*client.Persistent,
|
||||
) {
|
||||
tb.Helper()
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cm := range data {
|
||||
for _, cj := range cm {
|
||||
var c *client.Persistent
|
||||
c, err := clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleAddClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "add_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "add_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
client: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
body, err := json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleAddClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleDelClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "remove_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "empty_client_name",
|
||||
client: newPersistentClient(""),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "remove_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleDelClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
|
||||
|
||||
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
clientName string
|
||||
modified *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "update_one",
|
||||
clientName: clientOne.Name,
|
||||
modified: clientModified,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_name",
|
||||
clientName: "",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "client_not_found",
|
||||
clientName: "client_not_found",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
clientName: clientModified.Name,
|
||||
modified: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "no_ids",
|
||||
clientName: clientModified.Name,
|
||||
modified: newPersistentClient("no_ids"),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
uj := updateJSON{
|
||||
Name: tc.clientName,
|
||||
Data: *clientToJSON(tc.modified),
|
||||
}
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(uj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleUpdateClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
clients.clientChecker = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
|
||||
return false, ""
|
||||
},
|
||||
}
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query url.Values
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "single",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
"ip1": []string{testClientIP2},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.URL.RawQuery = tc.query.Encode()
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleFindClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -148,7 +148,7 @@ func initDNSServer(
|
|||
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
Context.clients.clientChecker = Context.dnsServer
|
||||
|
||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue