mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-21 12:35:33 +03:00
Pull request 1883: 951-blocked-services-client-schedule
Updates #951. Squashed commit of the following: commit 94e4766932940a99c5265489bccb46d0ed6cec25 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 27 17:21:41 2023 +0300 chlog: upd docs commit b4022c33860c258bf29650413f0c972b849a1758 Merge: cfa24ff01e7e638443
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 27 16:33:20 2023 +0300 Merge branch 'master' into 951-blocked-services-client-schedule commit cfa24ff0190b2bc12736700eeff815525fbaf5fe Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 27 15:04:10 2023 +0300 chlog: imp docs commit dad27590d5eefde82758d58fc06a20c139492db8 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Jun 26 17:38:08 2023 +0300 home: imp err msg commit 7d9ba98c4477000fc2e0f06c3462fe9cd0c65293 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Jun 26 16:58:00 2023 +0300 all: add tests commit 8e952fc4e3b3d433b29efe47c88d6b7806e99ff8 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jun 23 16:36:10 2023 +0300 schedule: add todo commit 723573a98d5b930334a5fa125eb12593f4a2430d Merge: 2151ab2a6e54fc9b1e
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jun 23 11:40:03 2023 +0300 Merge branch 'master' into 951-blocked-services-client-schedule commit 2151ab2a627b9833ba6cce9621f72b29d326da75 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jun 23 11:37:49 2023 +0300 all: add tests commit 81ab341db3e4053f09b181d8111c0da197bdac05 Merge: aa7ae41a866345e855
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Jun 22 17:59:01 2023 +0300 Merge branch 'master' into 951-blocked-services-client-schedule commit aa7ae41a868045fe24e390b25f15551fd8821529 Merge: 304389a4806d465b0d
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jun 21 17:10:11 2023 +0300 Merge branch 'master' into 951-blocked-services-client-schedule commit 304389a487f728e8ced293ea811a4e0026a37f0d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jun 21 17:05:31 2023 +0300 home: imp err msg commit 29cfc7ae2a0bbd5ec3205eae3f6f810519787f26 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 20 20:42:59 2023 +0300 all: imp err handling commit 8543868eef6442fd30131d9567b66222999101e9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 20 18:21:50 2023 +0300 all: upd chlog commit c5b614d45e5cf25c30c52343f48139fb34c77539 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 20 14:37:47 2023 +0300 all: add blocked services schedule
This commit is contained in:
parent
e7e638443f
commit
d88181343c
14 changed files with 418 additions and 71 deletions
42
CHANGELOG.md
42
CHANGELOG.md
|
@ -27,9 +27,9 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
|||
|
||||
- The new command-line flag `--web-addr` is the address to serve the web UI on,
|
||||
in the host:port format.
|
||||
- The ability to set inactivity periods for filtering blocked services in the
|
||||
configuration file ([#951]). The UI changes are coming in the upcoming
|
||||
releases.
|
||||
- The ability to set inactivity periods for filtering blocked services, both
|
||||
globally and per client, in the configuration file ([#951]). The UI changes
|
||||
are coming in the upcoming releases.
|
||||
- The ability to edit rewrite rules via `PUT /control/rewrite/update` HTTP API
|
||||
and the Web UI ([#1577]).
|
||||
|
||||
|
@ -37,8 +37,42 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
|||
|
||||
#### Configuration Changes
|
||||
|
||||
In this release, the schema version has changed from 20 to 21.
|
||||
In this release, the schema version has changed from 20 to 22.
|
||||
|
||||
- Property `clients.persistent.blocked_services`, which in schema versions 21
|
||||
and earlier used to be a list containing ids of blocked services, is now an
|
||||
object containing ids and schedule for blocked services:
|
||||
|
||||
```yaml
|
||||
# BEFORE:
|
||||
'clients':
|
||||
'persistent':
|
||||
- 'name': 'client-name'
|
||||
'blocked_services':
|
||||
- id_1
|
||||
- id_2
|
||||
|
||||
# AFTER:
|
||||
'clients':
|
||||
'persistent':
|
||||
- 'name': client-name
|
||||
'blocked_services':
|
||||
'ids':
|
||||
- id_1
|
||||
- id_2
|
||||
'schedule':
|
||||
'time_zone': 'Local'
|
||||
'sun':
|
||||
'start': '0s'
|
||||
'end': '24h'
|
||||
'mon':
|
||||
'start': '1h'
|
||||
'end': '23h'
|
||||
```
|
||||
|
||||
To rollback this change, replace `clients.persistent.blocked_services` object
|
||||
with the list of ids of blocked services and change the `schema_version` back
|
||||
to `21`.
|
||||
- Property `dns.blocked_services`, which in schema versions 20 and earlier used
|
||||
to be a list containing ids of blocked services, is now an object containing
|
||||
ids and schedule for blocked services:
|
||||
|
|
|
@ -2,6 +2,7 @@ package filtering
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -55,11 +56,29 @@ type BlockedServices struct {
|
|||
IDs []string `yaml:"ids"`
|
||||
}
|
||||
|
||||
// BlockedSvcKnown returns true if a blocked service ID is known.
|
||||
func BlockedSvcKnown(s string) (ok bool) {
|
||||
_, ok = serviceRules[s]
|
||||
// Clone returns a deep copy of blocked services.
|
||||
func (s *BlockedServices) Clone() (c *BlockedServices) {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ok
|
||||
return &BlockedServices{
|
||||
Schedule: s.Schedule.Clone(),
|
||||
IDs: slices.Clone(s.IDs),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate returns an error if blocked services contain unknown service ID. s
|
||||
// must not be nil.
|
||||
func (s *BlockedServices) Validate() (err error) {
|
||||
for _, id := range s.IDs {
|
||||
_, ok := serviceRules[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown blocked-service %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||
|
|
|
@ -988,17 +988,11 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
|||
}
|
||||
|
||||
if d.BlockedServices != nil {
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices.IDs {
|
||||
if !BlockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service %q", s)
|
||||
err = d.BlockedServices.Validate()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
bsvcs = append(bsvcs, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering: %w", err)
|
||||
}
|
||||
d.BlockedServices.IDs = bsvcs
|
||||
}
|
||||
|
||||
if blockFilters != nil {
|
||||
|
|
|
@ -23,12 +23,14 @@ type Client struct {
|
|||
safeSearchConf filtering.SafeSearchConfig
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices
|
||||
|
||||
Name string
|
||||
|
||||
IDs []string
|
||||
Tags []string
|
||||
BlockedServices []string
|
||||
Upstreams []string
|
||||
IDs []string
|
||||
Tags []string
|
||||
Upstreams []string
|
||||
|
||||
UseOwnSettings bool
|
||||
FilteringEnabled bool
|
||||
|
@ -44,9 +46,9 @@ type Client struct {
|
|||
func (c *Client) ShallowClone() (sh *Client) {
|
||||
clone := *c
|
||||
|
||||
clone.BlockedServices = c.BlockedServices.Clone()
|
||||
clone.IDs = stringutil.CloneSlice(c.IDs)
|
||||
clone.Tags = stringutil.CloneSlice(c.Tags)
|
||||
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return &clone
|
||||
|
|
|
@ -96,7 +96,7 @@ func (clients *clientsContainer) Init(
|
|||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
filteringConf *filtering.Config,
|
||||
) {
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
|
@ -110,13 +110,17 @@ func (clients *clientsContainer) Init(
|
|||
clients.dhcpServer = dhcpServer
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects, filteringConf)
|
||||
err = clients.addFromConfig(objects, filteringConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if clients.dhcpServer != nil {
|
||||
|
@ -127,6 +131,8 @@ func (clients *clientsContainer) Init(
|
|||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
|
@ -166,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
|
|||
type clientObject struct {
|
||||
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
|
||||
Tags []string `yaml:"tags"`
|
||||
IDs []string `yaml:"ids"`
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
IDs []string `yaml:"ids"`
|
||||
Tags []string `yaml:"tags"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
|
@ -185,7 +193,10 @@ type clientObject struct {
|
|||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
|
||||
func (clients *clientsContainer) addFromConfig(
|
||||
objects []*clientObject,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: o.Name,
|
||||
|
@ -206,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
|||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err := cli.setSafeSearch(
|
||||
err = cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
|
@ -218,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
|||
}
|
||||
}
|
||||
|
||||
for _, s := range o.BlockedServices {
|
||||
if filtering.BlockedSvcKnown(s) {
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown blocked service %q", s)
|
||||
}
|
||||
err = o.BlockedServices.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
|
||||
}
|
||||
|
||||
cli.BlockedServices = o.BlockedServices.Clone()
|
||||
|
||||
for _, t := range o.Tags {
|
||||
if clients.allTags.Has(t) {
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
|
@ -236,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
|||
|
||||
slices.Sort(cli.Tags)
|
||||
|
||||
_, err := clients.Add(cli)
|
||||
_, err = clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding clients %s: %s", cli.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forConfig returns all currently known persistent clients as objects for the
|
||||
|
@ -254,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
|||
o := &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
|
|
|
@ -16,18 +16,19 @@ import (
|
|||
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer() (c *clientsContainer) {
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
|
@ -198,7 +199,7 @@ func TestClients(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
whois := &whois.Info{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
|
@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
|
|
@ -123,10 +123,14 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
|||
|
||||
Name: cj.Name,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
Upstreams: cj.Upstreams,
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: prev.BlockedServices.Schedule.Clone(),
|
||||
IDs: cj.BlockedServices,
|
||||
},
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
|
@ -180,7 +184,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
|||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
BlockedServices: c.BlockedServices,
|
||||
|
||||
BlockedServices: c.BlockedServices.IDs,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
|
||||
|
|
|
@ -474,9 +474,11 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
|||
if c.UseOwnBlockedServices {
|
||||
// TODO(e.burkov): Get rid of this crutch.
|
||||
setts.ServicesRules = nil
|
||||
svcs := c.BlockedServices
|
||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
svcs := c.BlockedServices.IDs
|
||||
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
|
|
|
@ -6,9 +6,87 @@ import (
|
|||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"default": {
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
FilteringEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
"custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
},
|
||||
"partial_custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
FilteringEnabled assert.BoolAssertionFunc
|
||||
SafeSearchEnabled assert.BoolAssertionFunc
|
||||
SafeBrowsingEnabled assert.BoolAssertionFunc
|
||||
ParentalEnabled assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "default",
|
||||
FilteringEnabled: assert.False,
|
||||
SafeSearchEnabled: assert.False,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.True,
|
||||
ParentalEnabled: assert.True,
|
||||
}, {
|
||||
name: "partial",
|
||||
id: "partial_custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
tc.FilteringEnabled(t, setts.FilteringEnabled)
|
||||
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
|
||||
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
|
||||
tc.ParentalEnabled(t, setts.ParentalEnabled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
filtering.InitModule()
|
||||
|
||||
|
@ -29,43 +107,61 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"client_1": {
|
||||
"default": {
|
||||
UseOwnBlockedServices: false,
|
||||
},
|
||||
"client_2": {
|
||||
"no_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"client_3": {
|
||||
BlockedServices: clientBlockedServices,
|
||||
"services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"client_4": {
|
||||
BlockedServices: invalidBlockedServices,
|
||||
"invalid_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: invalidBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"allow_all": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.FullWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
id string
|
||||
setts *filtering.Settings
|
||||
wantLen int
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "client_1",
|
||||
id: "default",
|
||||
wantLen: len(globalBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "client_2",
|
||||
id: "no_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_block",
|
||||
id: "client_3",
|
||||
id: "services",
|
||||
wantLen: len(clientBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings_invalid",
|
||||
id: "client_4",
|
||||
id: "invalid_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_inactive_schedule",
|
||||
id: "allow_all",
|
||||
wantLen: 0,
|
||||
}}
|
||||
|
||||
|
|
|
@ -355,13 +355,17 @@ func initContextClients() (err error) {
|
|||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(
|
||||
err = Context.clients.Init(
|
||||
config.Clients.Persistent,
|
||||
Context.dhcpServer,
|
||||
Context.etcHosts,
|
||||
arpdb,
|
||||
config.DNS.DnsfilterConf,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -228,7 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
|||
for _, tc := range testCases {
|
||||
w.Reset()
|
||||
|
||||
cc := newClientsContainer()
|
||||
cc := newClientsContainer(t)
|
||||
ch := make(chan netip.Addr)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 21
|
||||
const currentSchemaVersion = 22
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
|
@ -95,6 +95,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
|||
upgradeSchema18to19,
|
||||
upgradeSchema19to20,
|
||||
upgradeSchema20to21,
|
||||
upgradeSchema21to22,
|
||||
}
|
||||
|
||||
n := 0
|
||||
|
@ -1179,6 +1180,82 @@ func upgradeSchema20to21(diskConf yobj) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema21to22 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// - 'svc_name'
|
||||
//
|
||||
// # AFTER:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// 'ids':
|
||||
// - 'svc_name'
|
||||
// 'schedule':
|
||||
// 'time_zone': 'Local'
|
||||
func upgradeSchema21to22(diskConf yobj) (err error) {
|
||||
log.Println("Upgrade yaml: 21 to 22")
|
||||
diskConf["schema_version"] = 22
|
||||
|
||||
const field = "blocked_services"
|
||||
|
||||
clientsVal, ok := diskConf["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients, ok := clientsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
|
||||
}
|
||||
|
||||
persistentVal, ok := clients["persistent"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
persistent, ok := persistentVal.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
|
||||
}
|
||||
|
||||
for i, val := range persistent {
|
||||
var c yobj
|
||||
c, ok = val.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
|
||||
}
|
||||
|
||||
var blockedVal any
|
||||
blockedVal, ok = c[field]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var services yarr
|
||||
services, ok = blockedVal.(yarr)
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
"persistent client at index %d: unexpected type of blocked services: %T",
|
||||
i,
|
||||
blockedVal,
|
||||
)
|
||||
}
|
||||
|
||||
c[field] = yobj{
|
||||
"ids": services,
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
|
|
@ -1183,3 +1183,73 @@ func TestUpgradeSchema20to21(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema21to22(t *testing.T) {
|
||||
const newSchemaVer = 22
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{
|
||||
"clients": yobj{},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "nothing",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "no_services",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{"ok"},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "services",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema21to22(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,36 @@ func EmptyWeekly() (w *Weekly) {
|
|||
}
|
||||
}
|
||||
|
||||
// FullWeekly creates full weekly schedule with local time zone.
|
||||
//
|
||||
// TODO(s.chzhen): Consider moving into tests.
|
||||
func FullWeekly() (w *Weekly) {
|
||||
fullDay := dayRange{start: 0, end: maxDayRange}
|
||||
|
||||
return &Weekly{
|
||||
location: time.Local,
|
||||
days: [7]dayRange{
|
||||
time.Sunday: fullDay,
|
||||
time.Monday: fullDay,
|
||||
time.Tuesday: fullDay,
|
||||
time.Wednesday: fullDay,
|
||||
time.Thursday: fullDay,
|
||||
time.Friday: fullDay,
|
||||
time.Saturday: fullDay,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of a weekly.
|
||||
func (w *Weekly) Clone() (c *Weekly) {
|
||||
// NOTE: Do not use time.LoadLocation, because the results will be
|
||||
// different on time zone database update.
|
||||
return &Weekly{
|
||||
location: w.location,
|
||||
days: w.days,
|
||||
}
|
||||
}
|
||||
|
||||
// Contains returns true if t is within the corresponding day range of the
|
||||
// schedule in the schedule's time zone.
|
||||
func (w *Weekly) Contains(t time.Time) (ok bool) {
|
||||
|
|
Loading…
Reference in a new issue