diff --git a/CHANGELOG.md b/CHANGELOG.md index d7904040..7e9bf7e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,9 +27,9 @@ NOTE: Add new changes BELOW THIS COMMENT. - The new command-line flag `--web-addr` is the address to serve the web UI on, in the host:port format. -- The ability to set inactivity periods for filtering blocked services in the - configuration file ([#951]). The UI changes are coming in the upcoming - releases. +- The ability to set inactivity periods for filtering blocked services, both + globally and per client, in the configuration file ([#951]). The UI changes + are coming in the upcoming releases. - The ability to edit rewrite rules via `PUT /control/rewrite/update` HTTP API and the Web UI ([#1577]). @@ -37,8 +37,42 @@ NOTE: Add new changes BELOW THIS COMMENT. #### Configuration Changes -In this release, the schema version has changed from 20 to 21. +In this release, the schema version has changed from 20 to 22. +- Property `clients.persistent.blocked_services`, which in schema versions 21 + and earlier used to be a list containing ids of blocked services, is now an + object containing ids and schedule for blocked services: + + ```yaml + # BEFORE: + 'clients': + 'persistent': + - 'name': 'client-name' + 'blocked_services': + - id_1 + - id_2 + + # AFTER: + 'clients': + 'persistent': + - 'name': client-name + 'blocked_services': + 'ids': + - id_1 + - id_2 + 'schedule': + 'time_zone': 'Local' + 'sun': + 'start': '0s' + 'end': '24h' + 'mon': + 'start': '1h' + 'end': '23h' + ``` + + To rollback this change, replace `clients.persistent.blocked_services` object + with the list of ids of blocked services and change the `schema_version` back + to `21`. - Property `dns.blocked_services`, which in schema versions 20 and earlier used to be a list containing ids of blocked services, is now an object containing ids and schedule for blocked services: diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index d8c8b9d2..f403d0ab 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -2,6 +2,7 @@ package filtering import ( "encoding/json" + "fmt" "net/http" "time" @@ -55,11 +56,29 @@ type BlockedServices struct { IDs []string `yaml:"ids"` } -// BlockedSvcKnown returns true if a blocked service ID is known. -func BlockedSvcKnown(s string) (ok bool) { - _, ok = serviceRules[s] +// Clone returns a deep copy of blocked services. +func (s *BlockedServices) Clone() (c *BlockedServices) { + if s == nil { + return nil + } - return ok + return &BlockedServices{ + Schedule: s.Schedule.Clone(), + IDs: slices.Clone(s.IDs), + } +} + +// Validate returns an error if blocked services contain unknown service ID. s +// must not be nil. +func (s *BlockedServices) Validate() (err error) { + for _, id := range s.IDs { + _, ok := serviceRules[id] + if !ok { + return fmt.Errorf("unknown blocked-service %q", id) + } + } + + return nil } // ApplyBlockedServices - set blocked services settings for this DNS request diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index ea6d4bfb..7cad6c99 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -988,17 +988,11 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { } if d.BlockedServices != nil { - bsvcs := []string{} - for _, s := range d.BlockedServices.IDs { - if !BlockedSvcKnown(s) { - log.Debug("skipping unknown blocked-service %q", s) + err = d.BlockedServices.Validate() - continue - } - - bsvcs = append(bsvcs, s) + if err != nil { + return nil, fmt.Errorf("filtering: %w", err) } - d.BlockedServices.IDs = bsvcs } if blockFilters != nil { diff --git a/internal/home/client.go b/internal/home/client.go index 92c88385..c0d39dad 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -23,12 +23,14 @@ type Client struct { safeSearchConf filtering.SafeSearchConfig SafeSearch filtering.SafeSearch + // BlockedServices is the configuration of blocked services of a client. + BlockedServices *filtering.BlockedServices + Name string - IDs []string - Tags []string - BlockedServices []string - Upstreams []string + IDs []string + Tags []string + Upstreams []string UseOwnSettings bool FilteringEnabled bool @@ -44,9 +46,9 @@ type Client struct { func (c *Client) ShallowClone() (sh *Client) { clone := *c + clone.BlockedServices = c.BlockedServices.Clone() clone.IDs = stringutil.CloneSlice(c.IDs) clone.Tags = stringutil.CloneSlice(c.Tags) - clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices) clone.Upstreams = stringutil.CloneSlice(c.Upstreams) return &clone diff --git a/internal/home/clients.go b/internal/home/clients.go index 7771f221..2e6323c8 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -96,7 +96,7 @@ func (clients *clientsContainer) Init( etcHosts *aghnet.HostsContainer, arpdb aghnet.ARPDB, filteringConf *filtering.Config, -) { +) (err error) { if clients.list != nil { log.Fatal("clients.list != nil") } @@ -110,13 +110,17 @@ func (clients *clientsContainer) Init( clients.dhcpServer = dhcpServer clients.etcHosts = etcHosts clients.arpdb = arpdb - clients.addFromConfig(objects, filteringConf) + err = clients.addFromConfig(objects, filteringConf) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) if clients.testing { - return + return nil } if clients.dhcpServer != nil { @@ -127,6 +131,8 @@ func (clients *clientsContainer) Init( if clients.etcHosts != nil { go clients.handleHostsUpdates() } + + return nil } func (clients *clientsContainer) handleHostsUpdates() { @@ -166,12 +172,14 @@ func (clients *clientsContainer) reloadARP() { type clientObject struct { SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"` + // BlockedServices is the configuration of blocked services of a client. + BlockedServices *filtering.BlockedServices `yaml:"blocked_services"` + Name string `yaml:"name"` - Tags []string `yaml:"tags"` - IDs []string `yaml:"ids"` - BlockedServices []string `yaml:"blocked_services"` - Upstreams []string `yaml:"upstreams"` + IDs []string `yaml:"ids"` + Tags []string `yaml:"tags"` + Upstreams []string `yaml:"upstreams"` UseGlobalSettings bool `yaml:"use_global_settings"` FilteringEnabled bool `yaml:"filtering_enabled"` @@ -185,7 +193,10 @@ type clientObject struct { // addFromConfig initializes the clients container with objects from the // configuration file. -func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) { +func (clients *clientsContainer) addFromConfig( + objects []*clientObject, + filteringConf *filtering.Config, +) (err error) { for _, o := range objects { cli := &Client{ Name: o.Name, @@ -206,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin if o.SafeSearchConf.Enabled { o.SafeSearchConf.CustomResolver = safeSearchResolver{} - err := cli.setSafeSearch( + err = cli.setSafeSearch( o.SafeSearchConf, filteringConf.SafeSearchCacheSize, time.Minute*time.Duration(filteringConf.CacheTime), @@ -218,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin } } - for _, s := range o.BlockedServices { - if filtering.BlockedSvcKnown(s) { - cli.BlockedServices = append(cli.BlockedServices, s) - } else { - log.Info("clients: skipping unknown blocked service %q", s) - } + err = o.BlockedServices.Validate() + if err != nil { + return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err) } + cli.BlockedServices = o.BlockedServices.Clone() + for _, t := range o.Tags { if clients.allTags.Has(t) { cli.Tags = append(cli.Tags, t) @@ -236,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin slices.Sort(cli.Tags) - _, err := clients.Add(cli) + _, err = clients.Add(cli) if err != nil { log.Error("clients: adding clients %s: %s", cli.Name, err) } } + + return nil } // forConfig returns all currently known persistent clients as objects for the @@ -254,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { o := &clientObject{ Name: cli.Name, - Tags: stringutil.CloneSlice(cli.Tags), - IDs: stringutil.CloneSlice(cli.IDs), - BlockedServices: stringutil.CloneSlice(cli.BlockedServices), - Upstreams: stringutil.CloneSlice(cli.Upstreams), + BlockedServices: cli.BlockedServices.Clone(), + + IDs: stringutil.CloneSlice(cli.IDs), + Tags: stringutil.CloneSlice(cli.Tags), + Upstreams: stringutil.CloneSlice(cli.Upstreams), UseGlobalSettings: !cli.UseOwnSettings, FilteringEnabled: cli.FilteringEnabled, diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index b203415f..9ad819ec 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -16,18 +16,19 @@ import ( // newClientsContainer is a helper that creates a new clients container for // tests. -func newClientsContainer() (c *clientsContainer) { +func newClientsContainer(t *testing.T) (c *clientsContainer) { c = &clientsContainer{ testing: true, } - c.Init(nil, nil, nil, nil, &filtering.Config{}) + err := c.Init(nil, nil, nil, nil, &filtering.Config{}) + require.NoError(t, err) return c } func TestClients(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) t.Run("add_success", func(t *testing.T) { var ( @@ -198,7 +199,7 @@ func TestClients(t *testing.T) { } func TestClientsWHOIS(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) whois := &whois.Info{ Country: "AU", Orgname: "Example Org", @@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) { } func TestClientsAddExisting(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) t.Run("simple", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") @@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) { } func TestClientsCustomUpstream(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) // Add client with upstreams. ok, err := clients.Add(&Client{ diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 9eb91341..8028164a 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -123,10 +123,14 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C Name: cj.Name, - IDs: cj.IDs, - Tags: cj.Tags, - BlockedServices: cj.BlockedServices, - Upstreams: cj.Upstreams, + BlockedServices: &filtering.BlockedServices{ + Schedule: prev.BlockedServices.Schedule.Clone(), + IDs: cj.BlockedServices, + }, + + IDs: cj.IDs, + Tags: cj.Tags, + Upstreams: cj.Upstreams, UseOwnSettings: !cj.UseGlobalSettings, FilteringEnabled: cj.FilteringEnabled, @@ -180,7 +184,8 @@ func clientToJSON(c *Client) (cj *clientJSON) { SafeBrowsingEnabled: c.SafeBrowsingEnabled, UseGlobalBlockedServices: !c.UseOwnBlockedServices, - BlockedServices: c.BlockedServices, + + BlockedServices: c.BlockedServices.IDs, Upstreams: c.Upstreams, diff --git a/internal/home/dns.go b/internal/home/dns.go index 344f0be0..3a37f751 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -474,9 +474,11 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering if c.UseOwnBlockedServices { // TODO(e.burkov): Get rid of this crutch. setts.ServicesRules = nil - svcs := c.BlockedServices - Context.filters.ApplyBlockedServicesList(setts, svcs) - log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) + svcs := c.BlockedServices.IDs + if !c.BlockedServices.Schedule.Contains(time.Now()) { + Context.filters.ApplyBlockedServicesList(setts, svcs) + log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) + } } setts.ClientName = c.Name diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 9450cdf6..8ba988f2 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -6,9 +6,87 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestApplyAdditionalFiltering(t *testing.T) { + var err error + + Context.filters, err = filtering.New(&filtering.Config{ + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + }, + }, nil) + require.NoError(t, err) + + Context.clients.idIndex = map[string]*Client{ + "default": { + UseOwnSettings: false, + safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, + FilteringEnabled: false, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }, + "custom_filtering": { + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: true, + ParentalEnabled: true, + }, + "partial_custom_filtering": { + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }, + } + + testCases := []struct { + name string + id string + FilteringEnabled assert.BoolAssertionFunc + SafeSearchEnabled assert.BoolAssertionFunc + SafeBrowsingEnabled assert.BoolAssertionFunc + ParentalEnabled assert.BoolAssertionFunc + }{{ + name: "global_settings", + id: "default", + FilteringEnabled: assert.False, + SafeSearchEnabled: assert.False, + SafeBrowsingEnabled: assert.False, + ParentalEnabled: assert.False, + }, { + name: "custom_settings", + id: "custom_filtering", + FilteringEnabled: assert.True, + SafeSearchEnabled: assert.True, + SafeBrowsingEnabled: assert.True, + ParentalEnabled: assert.True, + }, { + name: "partial", + id: "partial_custom_filtering", + FilteringEnabled: assert.True, + SafeSearchEnabled: assert.True, + SafeBrowsingEnabled: assert.False, + ParentalEnabled: assert.False, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setts := &filtering.Settings{} + + applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts) + tc.FilteringEnabled(t, setts.FilteringEnabled) + tc.SafeSearchEnabled(t, setts.SafeSearchEnabled) + tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled) + tc.ParentalEnabled(t, setts.ParentalEnabled) + }) + } +} + func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { filtering.InitModule() @@ -29,43 +107,61 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { require.NoError(t, err) Context.clients.idIndex = map[string]*Client{ - "client_1": { + "default": { UseOwnBlockedServices: false, }, - "client_2": { + "no_services": { + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + }, UseOwnBlockedServices: true, }, - "client_3": { - BlockedServices: clientBlockedServices, + "services": { + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: clientBlockedServices, + }, UseOwnBlockedServices: true, }, - "client_4": { - BlockedServices: invalidBlockedServices, + "invalid_services": { + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: invalidBlockedServices, + }, + UseOwnBlockedServices: true, + }, + "allow_all": { + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.FullWeekly(), + IDs: clientBlockedServices, + }, UseOwnBlockedServices: true, }, } testCases := []struct { name string - ip net.IP id string - setts *filtering.Settings wantLen int }{{ name: "global_settings", - id: "client_1", + id: "default", wantLen: len(globalBlockedServices), }, { name: "custom_settings", - id: "client_2", + id: "no_services", wantLen: 0, }, { name: "custom_settings_block", - id: "client_3", + id: "services", wantLen: len(clientBlockedServices), }, { name: "custom_settings_invalid", - id: "client_4", + id: "invalid_services", + wantLen: 0, + }, { + name: "custom_settings_inactive_schedule", + id: "allow_all", wantLen: 0, }} diff --git a/internal/home/home.go b/internal/home/home.go index a1525e95..87dba706 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -355,13 +355,17 @@ func initContextClients() (err error) { arpdb = aghnet.NewARPDB() } - Context.clients.Init( + err = Context.clients.Init( config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf, ) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } return nil } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 078a092f..5582bf5b 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -228,7 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { for _, tc := range testCases { w.Reset() - cc := newClientsContainer() + cc := newClientsContainer(t) ch := make(chan netip.Addr) rdns := &RDNS{ exchanger: &rDNSExchanger{ diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index 2326bdb7..d099378b 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -22,7 +22,7 @@ import ( ) // currentSchemaVersion is the current schema version. -const currentSchemaVersion = 21 +const currentSchemaVersion = 22 // These aliases are provided for convenience. type ( @@ -95,6 +95,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) { upgradeSchema18to19, upgradeSchema19to20, upgradeSchema20to21, + upgradeSchema21to22, } n := 0 @@ -1179,6 +1180,82 @@ func upgradeSchema20to21(diskConf yobj) (err error) { return nil } +// upgradeSchema21to22 performs the following changes: +// +// # BEFORE: +// 'persistent': +// - 'name': 'client_name' +// 'blocked_services': +// - 'svc_name' +// +// # AFTER: +// 'persistent': +// - 'name': 'client_name' +// 'blocked_services': +// 'ids': +// - 'svc_name' +// 'schedule': +// 'time_zone': 'Local' +func upgradeSchema21to22(diskConf yobj) (err error) { + log.Println("Upgrade yaml: 21 to 22") + diskConf["schema_version"] = 22 + + const field = "blocked_services" + + clientsVal, ok := diskConf["clients"] + if !ok { + return nil + } + + clients, ok := clientsVal.(yobj) + if !ok { + return fmt.Errorf("unexpected type of clients: %T", clientsVal) + } + + persistentVal, ok := clients["persistent"] + if !ok { + return nil + } + + persistent, ok := persistentVal.([]any) + if !ok { + return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal) + } + + for i, val := range persistent { + var c yobj + c, ok = val.(yobj) + if !ok { + return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val) + } + + var blockedVal any + blockedVal, ok = c[field] + if !ok { + continue + } + + var services yarr + services, ok = blockedVal.(yarr) + if !ok { + return fmt.Errorf( + "persistent client at index %d: unexpected type of blocked services: %T", + i, + blockedVal, + ) + } + + c[field] = yobj{ + "ids": services, + "schedule": yobj{ + "time_zone": "Local", + }, + } + } + + return nil +} + // TODO(a.garipov): Replace with log.Output when we port it to our logging // package. func funcName() string { diff --git a/internal/home/upgrade_test.go b/internal/home/upgrade_test.go index aabf514f..1839cc28 100644 --- a/internal/home/upgrade_test.go +++ b/internal/home/upgrade_test.go @@ -1183,3 +1183,73 @@ func TestUpgradeSchema20to21(t *testing.T) { }) } } + +func TestUpgradeSchema21to22(t *testing.T) { + const newSchemaVer = 22 + + testCases := []struct { + in yobj + want yobj + name string + }{{ + in: yobj{ + "clients": yobj{}, + }, + want: yobj{ + "clients": yobj{}, + "schema_version": newSchemaVer, + }, + name: "nothing", + }, { + in: yobj{ + "clients": yobj{ + "persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}}, + }, + }, + want: yobj{ + "clients": yobj{ + "persistent": []any{yobj{ + "name": "localhost", + "blocked_services": yobj{ + "ids": yarr{}, + "schedule": yobj{ + "time_zone": "Local", + }, + }, + }}, + }, + "schema_version": newSchemaVer, + }, + name: "no_services", + }, { + in: yobj{ + "clients": yobj{ + "persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}}, + }, + }, + want: yobj{ + "clients": yobj{ + "persistent": []any{yobj{ + "name": "localhost", + "blocked_services": yobj{ + "ids": yarr{"ok"}, + "schedule": yobj{ + "time_zone": "Local", + }, + }, + }}, + }, + "schema_version": newSchemaVer, + }, + name: "services", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := upgradeSchema21to22(tc.in) + require.NoError(t, err) + + assert.Equal(t, tc.want, tc.in) + }) + } +} diff --git a/internal/schedule/schedule.go b/internal/schedule/schedule.go index ba3757f9..1bf96016 100644 --- a/internal/schedule/schedule.go +++ b/internal/schedule/schedule.go @@ -28,6 +28,36 @@ func EmptyWeekly() (w *Weekly) { } } +// FullWeekly creates full weekly schedule with local time zone. +// +// TODO(s.chzhen): Consider moving into tests. +func FullWeekly() (w *Weekly) { + fullDay := dayRange{start: 0, end: maxDayRange} + + return &Weekly{ + location: time.Local, + days: [7]dayRange{ + time.Sunday: fullDay, + time.Monday: fullDay, + time.Tuesday: fullDay, + time.Wednesday: fullDay, + time.Thursday: fullDay, + time.Friday: fullDay, + time.Saturday: fullDay, + }, + } +} + +// Clone returns a deep copy of a weekly. +func (w *Weekly) Clone() (c *Weekly) { + // NOTE: Do not use time.LoadLocation, because the results will be + // different on time zone database update. + return &Weekly{ + location: w.location, + days: w.days, + } +} + // Contains returns true if t is within the corresponding day range of the // schedule in the schedule's time zone. func (w *Weekly) Contains(t time.Time) (ok bool) {