mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-21 20:45:33 +03:00
client: imp code
This commit is contained in:
parent
79272b299a
commit
6cc4ed53a2
4 changed files with 41 additions and 14 deletions
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
@ -136,7 +135,8 @@ type Persistent struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if persistent client information contains errors.
|
// validate returns an error if persistent client information contains errors.
|
||||||
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
// allTags must be sorted.
|
||||||
|
func (c *Persistent) validate(allTags []string) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case c.Name == "":
|
case c.Name == "":
|
||||||
return errors.Error("empty name")
|
return errors.Error("empty name")
|
||||||
|
@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range c.Tags {
|
for _, t := range c.Tags {
|
||||||
if !allTags.Has(t) {
|
_, ok := slices.BinarySearch(allTags, t)
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("invalid tag: %q", t)
|
return fmt.Errorf("invalid tag: %q", t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -132,7 +131,7 @@ func TestPersistent_Validate(t *testing.T) {
|
||||||
notAllowedTag = "not_allowed_tag"
|
notAllowedTag = "not_allowed_tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
allowedTags := container.NewMapSet(allowedTag)
|
allowedTags := []string{allowedTag}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -5,13 +5,13 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -108,9 +108,6 @@ type StorageConfig struct {
|
||||||
|
|
||||||
// Storage contains information about persistent and runtime clients.
|
// Storage contains information about persistent and runtime clients.
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
// allowedTags is a set of all allowed tags.
|
|
||||||
allowedTags *container.MapSet[string]
|
|
||||||
|
|
||||||
// mu protects indexes of persistent and runtime clients.
|
// mu protects indexes of persistent and runtime clients.
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
@ -132,6 +129,12 @@ type Storage struct {
|
||||||
// done is the shutdown signaling channel.
|
// done is the shutdown signaling channel.
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
|
|
||||||
|
// allowedTags is a sorted list of all allowed tags. It must not be
|
||||||
|
// modified after initialization.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use custom type.
|
||||||
|
allowedTags []string
|
||||||
|
|
||||||
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
|
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
|
||||||
// information is updated. It must be greater than zero.
|
// information is updated. It must be greater than zero.
|
||||||
arpClientsUpdatePeriod time.Duration
|
arpClientsUpdatePeriod time.Duration
|
||||||
|
@ -143,8 +146,11 @@ type Storage struct {
|
||||||
|
|
||||||
// NewStorage returns initialized client storage. conf must not be nil.
|
// NewStorage returns initialized client storage. conf must not be nil.
|
||||||
func NewStorage(conf *StorageConfig) (s *Storage, err error) {
|
func NewStorage(conf *StorageConfig) (s *Storage, err error) {
|
||||||
|
tags := slices.Clone(allowedTags)
|
||||||
|
slices.Sort(tags)
|
||||||
|
|
||||||
s = &Storage{
|
s = &Storage{
|
||||||
allowedTags: container.NewMapSet(allowedTags...),
|
allowedTags: tags,
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
index: newIndex(),
|
index: newIndex(),
|
||||||
runtimeIndex: newRuntimeIndex(),
|
runtimeIndex: newRuntimeIndex(),
|
||||||
|
@ -576,7 +582,8 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||||
s.runtimeIndex.rangeClients(f)
|
s.runtimeIndex.rangeClients(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowedTags returns the list of available client tags.
|
// AllowedTags returns the list of available client tags. tags must not be
|
||||||
|
// modified.
|
||||||
func (s *Storage) AllowedTags() (tags []string) {
|
func (s *Storage) AllowedTags() (tags []string) {
|
||||||
return allowedTags
|
return s.allowedTags
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -536,7 +537,7 @@ func TestStorage_Add(t *testing.T) {
|
||||||
existingName = "existing_name"
|
existingName = "existing_name"
|
||||||
existingClientID = "existing_client_id"
|
existingClientID = "existing_client_id"
|
||||||
|
|
||||||
allowedTag = "tag"
|
allowedTag = "user_admin"
|
||||||
notAllowedTag = "not_allowed_tag"
|
notAllowedTag = "not_allowed_tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -557,6 +558,16 @@ func TestStorage_Add(t *testing.T) {
|
||||||
s, err := client.NewStorage(&client.StorageConfig{})
|
s, err := client.NewStorage(&client.StorageConfig{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tags := s.AllowedTags()
|
||||||
|
require.NotZero(t, len(tags))
|
||||||
|
require.True(t, slices.IsSorted(tags))
|
||||||
|
|
||||||
|
_, ok := slices.BinarySearch(tags, allowedTag)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
_, ok = slices.BinarySearch(tags, notAllowedTag)
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
err = s.Add(existingClient)
|
err = s.Add(existingClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -617,12 +628,21 @@ func TestStorage_Add(t *testing.T) {
|
||||||
}, {
|
}, {
|
||||||
name: "not_allowed_tag",
|
name: "not_allowed_tag",
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "nont_allowed_tag",
|
Name: "not_allowed_tag",
|
||||||
Tags: []string{notAllowedTag},
|
Tags: []string{notAllowedTag},
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||||
|
}, {
|
||||||
|
name: "allowed_tag",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "allowed_tag",
|
||||||
|
Tags: []string{allowedTag},
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
|
Loading…
Reference in a new issue