diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml index a5c40e4f..98db03e0 100644 --- a/.github/ISSUE_TEMPLATE/bug.yml +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -10,52 +10,58 @@ - 'label': > I have checked the [Wiki](https://github.com/AdguardTeam/AdGuardHome/wiki) and - [Discussions](https://github.com/AdguardTeam/AdGuardHome/discussions) + [Discussions](https://github.com/AdguardTeam/AdGuardHome/discussions/categories/q-a) and found no answer 'required': true - 'label': > I have searched other issues and found no duplicates 'required': true - 'label': > - I want to report a bug and not ask a question + I want to report a bug and not [ask a question or ask for + help](https://github.com/AdguardTeam/AdGuardHome/discussions/categories/q-a) + 'required': true + - 'label': > + I have set up AdGuard Home correctly and [configured clients to + use it](https://github.com/AdguardTeam/AdGuardHome/wiki/Clients). + (Use the + [Discussions](https://github.com/AdguardTeam/AdGuardHome/discussions/categories/q-a) + for help with installing and configuring clients.) 'required': true 'id': 'prerequisites' 'type': 'checkboxes' - 'attributes': - 'description': 'On which operating system type does the issue occur?' - 'label': 'Operating system type' + 'description': 'On which Platform does the issue occur?' + 'label': 'Platform (OS and CPU architecture)' 'options': - - 'FreeBSD' - - 'Linux, OpenWrt' - - 'Linux, Other (please mention the version in the description)' - - 'macOS (aka Darwin)' - - 'OpenBSD' - - 'Windows' - - 'Other (please mention in the description)' + - 'Darwin (aka macOS)/AMD64 (aka x86_64)' + - 'Darwin (aka macOS)/ARM64' + - 'FreeBSD/386' + - 'FreeBSD/AMD64 (aka x86_64)' + - 'FreeBSD/ARM64' + - 'FreeBSD/ARMv5' + - 'FreeBSD/ARMv6' + - 'FreeBSD/ARMv7' + - 'Linux/386' + - 'Linux/AMD64 (aka x86_64)' + - 'Linux/ARM64' + - 'Linux/ARMv5' + - 'Linux/ARMv6' + - 'Linux/ARMv7' + - 'Linux/MIPS LE' + - 'Linux/MIPS' + - 'Linux/MIPS64 LE' + - 'Linux/MIPS64' + - 'Linux/PPC64 LE' + - 'OpenBSD/AMD64 (aka x86_64)' + - 'OpenBSD/ARM64' + - 'Windows/386' + - 'Windows/AMD64 (aka x86_64)' + - 'Windows/ARM64' + - 'Custom (please mention in the description)' 'id': 'os' 'type': 'dropdown' 'validations': 'required': true - - 'attributes': - 'description': 'On which CPU architecture does the issue occur?' - 'label': 'CPU architecture' - 'options': - - 'AMD64' - - 'x86' - - '64-bit ARM' - - 'ARMv5' - - 'ARMv6' - - 'ARMv7' - - '64-bit MIPS' - - '64-bit MIPS LE' - - '32-bit MIPS' - - '32-bit MIPS LE' - - '64-bit PowerPC LE' - - 'Other (please mention in the description)' - 'id': 'arch' - 'type': 'dropdown' - 'validations': - 'required': true - 'attributes': 'description': 'How did you install AdGuard Home?' 'label': 'Installation' @@ -63,7 +69,7 @@ - 'GitHub releases or script from README' - 'Docker' - 'Snapcraft' - - 'Custom port' + - 'Custom package (OpenWrt, HomeAssistant, etc; please mention in the description)' - 'Other (please mention in the description)' 'id': 'install' 'type': 'dropdown' @@ -89,21 +95,55 @@ 'validations': 'required': true - 'attributes': - 'description': 'Please describe the bug' - 'label': 'Description' + 'description': > + Please describe what you did. An `nslookup` or a `dig` command is + the best way. For crashes, please provide a full failure log. + 'label': 'Action' 'value': | - #### What did you do? - - #### Expected result - - #### Actual result - - #### Screenshots (if applicable) - - #### Additional information - 'id': 'description' + ```sh + nslookup -debug -type=a 'www.example.com' '$YOUR_AGH_ADDRESS' + ``` + 'id': 'failing_action' 'type': 'textarea' 'validations': 'required': true -'description': 'File a bug report' + - 'attributes': + 'description': > + What did you expect to see? Please add a description and/or + screenshots, if applicable. + 'label': 'Expected result' + 'placeholder': > + What did you expect to see? + 'id': 'expected' + 'type': 'textarea' + 'validations': + 'required': true + - 'attributes': + 'description': > + What happened instead? Please add a description and/or screenshots, + if applicable. + 'label': 'Actual result' + 'placeholder': > + What did you see instead? + 'id': 'result' + 'type': 'textarea' + 'validations': + 'required': true + - 'attributes': + 'description': > + Please add additional information, such as non-standard OS or port, + here. You can also put screenshots here, if applicable. For + example, it is better to copy and paste text from a terminal instead + of posting a screenshot of the terminal. + 'label': 'Additional information and/or screenshots' + 'placeholder': > + Additional OS information, screenshots of the UI, etc. + 'id': 'additional' + 'type': 'textarea' + 'validations': + 'required': false +'description': > + Open a bug report. Please do not open bug reports for questions or help + with configuring clients. If you want to ask for help, use the Discussions + section. 'name': 'Bug' diff --git a/.github/ISSUE_TEMPLATE/feature.yml b/.github/ISSUE_TEMPLATE/feature.yml index 0ad6f5d8..154a137d 100644 --- a/.github/ISSUE_TEMPLATE/feature.yml +++ b/.github/ISSUE_TEMPLATE/feature.yml @@ -23,19 +23,32 @@ 'id': 'prerequisites' 'type': 'checkboxes' - 'attributes': - 'description': 'Please describe the request' - 'label': 'Description' - 'value': | - #### What problem are you trying to solve? - - #### Proposed solution - - #### Alternatives considered - - #### Additional information - 'id': 'description' + 'description': 'Please describe the problem you are trying to solve' + 'label': 'The problem' + 'placeholder': > + Please describe the problem you are trying to solve + 'id': 'problem' 'type': 'textarea' 'validations': 'required': true + - 'attributes': + 'description': 'What feature are you proposing to solve this problem?' + 'label': 'Proposed solution' + 'placeholder': > + What feature are you proposing to solve this problem? + 'id': 'proposed_solution' + 'type': 'textarea' + 'validations': + 'required': true + - 'attributes': + 'label': 'Alternatives considered and additional information' + 'placeholder': > + Are there any other ways to solve the problem? + 'id': 'additional' + 'type': 'textarea' + 'validations': + 'required': false 'description': 'Suggest a feature or an enhancement for AdGuard Home' +'labels': + - 'feature request' 'name': 'Feature request or enhancement' diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE new file mode 100644 index 00000000..d969343d --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE @@ -0,0 +1,20 @@ +Before submitting a PR please make sure that: + +1. You have discussed your solution in an issue and have got an + approval from a maintainer. + +2. This isn't a localization fix; please send those to our + [CrowdIn](https://crowdin.com/project/adguard-applications/en#/adguard-home) + page. + +3. Your code follows our + [code guidelines](https://github.com/AdguardTeam/CodeGuidelines/blob/master/Go/Go.md). + +Add a short description here. The description should include: + +1. Which issue this PR closes (`Closes #NNNN.`) or updates (`Updates + #NNNN.`). + +2. A short description of how the change achieves that. + +Do not forget to remove these instructions. diff --git a/.github/workflows/potential-duplicates.yml b/.github/workflows/potential-duplicates.yml new file mode 100644 index 00000000..dd065845 --- /dev/null +++ b/.github/workflows/potential-duplicates.yml @@ -0,0 +1,18 @@ +'name': 'potential-duplicates' +'on': + 'issues': + 'types': + - 'opened' +'jobs': + 'run': + 'runs-on': 'ubuntu-latest' + 'steps': + - 'uses': 'wow-actions/potential-duplicates@v1' + 'with': + 'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}' + 'state': 'all' + 'threshold': 0.6 + 'comment': | + Potential duplicates: {{#issues}} + * [#{{ number }}] {{ title }} ({{ accuracy }}%) + {{/issues}} diff --git a/CHANGELOG.md b/CHANGELOG.md index 91f53b16..99ba84e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,9 +28,67 @@ NOTE: Add new changes BELOW THIS COMMENT. - The new HTTP API, `GET /control/querylog/export`, which can be used to export query log items. See `openapi/openapi.yaml` for the full description ([#3389]). +- The ability to set inactivity periods for filtering blocked services in the + configuration file ([#951]). The UI changes are coming in the upcoming + releases. - The ability to edit rewrite rules via `PUT /control/rewrite/update` HTTP API ([#1577]). +### Changed + +#### Configuration Changes + +In this release, the schema version has changed from 20 to 21. + +- Property `dns.blocked_services`, which in schema versions 20 and earlier used + to be a list containing ids of blocked services, is now an object containing + ids and schedule for blocked services: + + ```yaml + # BEFORE: + 'blocked_services': + - id_1 + - id_2 + + # AFTER: + 'blocked_services': + 'ids': + - id_1 + - id_2 + 'schedule': + 'time_zone': 'Local' + 'sun': + 'start': '0s' + 'end': '24h' + 'mon': + 'start': '10m' + 'end': '23h30m' + 'tue': + 'start': '20m' + 'end': '23h' + 'wed': + 'start': '30m' + 'end': '22h30m' + 'thu': + 'start': '40m' + 'end': '22h' + 'fri': + 'start': '50m' + 'end': '21h30m' + 'sat': + 'start': '1h' + 'end': '21h' + ``` + + To rollback this change, replace `dns.blocked_services` object with the list + of ids of blocked services and change the `schema_version` back to `20`. + +### Fixed + + - DNSCrypt upstream not resetting the client and resolver information on + dialing errors ([#5872]). + +[#951]: https://github.com/AdguardTeam/AdGuardHome/issues/951 [#1577]: https://github.com/AdguardTeam/AdGuardHome/issues/1577 [#3389]: https://github.com/AdguardTeam/AdGuardHome/issues/3389 diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 6ee4e0f3..f55e3059 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -53,14 +53,14 @@ func (s *Server) beforeRequestHandler( // getClientRequestFilteringSettings looks up client filtering settings using // the client's IP address and ID, if any, from dctx. func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings { - setts := s.dnsFilter.GetConfig() + setts := s.dnsFilter.Settings() setts.ProtectionEnabled = dctx.protectionEnabled if s.conf.FilterHandler != nil { ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr) - s.conf.FilterHandler(ip, dctx.clientID, &setts) + s.conf.FilterHandler(ip, dctx.clientID, setts) } - return &setts + return setts } // filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 4e4878aa..d8c8b9d2 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -3,8 +3,10 @@ package filtering import ( "encoding/json" "net/http" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "golang.org/x/exp/slices" @@ -44,6 +46,15 @@ func initBlockedServices() { log.Debug("filtering: initialized %d services", l) } +// BlockedServices is the configuration of blocked services. +type BlockedServices struct { + // Schedule is blocked services schedule for every day of the week. + Schedule *schedule.Weekly `yaml:"schedule"` + + // IDs is the names of blocked services. + IDs []string `yaml:"ids"` +} + // BlockedSvcKnown returns true if a blocked service ID is known. func BlockedSvcKnown(s string) (ok bool) { _, ok = serviceRules[s] @@ -52,15 +63,22 @@ func BlockedSvcKnown(s string) (ok bool) { } // ApplyBlockedServices - set blocked services settings for this DNS request -func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) { +func (d *DNSFilter) ApplyBlockedServices(setts *Settings) { + d.confLock.RLock() + defer d.confLock.RUnlock() + setts.ServicesRules = []ServiceEntry{} - if list == nil { - d.confLock.RLock() - defer d.confLock.RUnlock() - list = d.Config.BlockedServices + bsvc := d.BlockedServices + + // TODO(s.chzhen): Use startTime from [dnsforward.dnsContext]. + if !bsvc.Schedule.Contains(time.Now()) { + d.ApplyBlockedServicesList(setts, bsvc.IDs) } +} +// ApplyBlockedServicesList appends filtering rules to the settings. +func (d *DNSFilter) ApplyBlockedServicesList(setts *Settings, list []string) { for _, name := range list { rules, ok := serviceRules[name] if !ok { @@ -90,7 +108,7 @@ func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Requ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { d.confLock.RLock() - list := d.Config.BlockedServices + list := d.Config.BlockedServices.IDs d.confLock.RUnlock() _ = aghhttp.WriteJSONResponse(w, r, list) @@ -106,7 +124,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ } d.confLock.Lock() - d.Config.BlockedServices = list + d.Config.BlockedServices.IDs = list d.confLock.Unlock() log.Debug("Updated blocked services list: %d", len(list)) diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 5c30c645..ea6d4bfb 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -103,9 +103,9 @@ type Config struct { Rewrites []*LegacyRewrite `yaml:"rewrites"` - // Names of services to block (globally). + // BlockedServices is the configuration of blocked services. // Per-client settings can override this configuration. - BlockedServices []string `yaml:"blocked_services"` + BlockedServices *BlockedServices `yaml:"blocked_services"` // EtcHosts is a container of IP-hostname pairs taken from the operating // system configuration files (e.g. /etc/hosts). @@ -298,12 +298,12 @@ func (d *DNSFilter) SetEnabled(enabled bool) { atomic.StoreUint32(&d.enabled, mathutil.BoolToNumber[uint32](enabled)) } -// GetConfig - get configuration -func (d *DNSFilter) GetConfig() (s Settings) { +// Settings returns filtering settings. +func (d *DNSFilter) Settings() (s *Settings) { d.confLock.RLock() defer d.confLock.RUnlock() - return Settings{ + return &Settings{ FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0, SafeSearchEnabled: d.Config.SafeSearchConf.Enabled, SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, @@ -987,16 +987,19 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { return nil, fmt.Errorf("rewrites: preparing: %s", err) } - bsvcs := []string{} - for _, s := range d.BlockedServices { - if !BlockedSvcKnown(s) { - log.Debug("skipping unknown blocked-service %q", s) + if d.BlockedServices != nil { + bsvcs := []string{} + for _, s := range d.BlockedServices.IDs { + if !BlockedSvcKnown(s) { + log.Debug("skipping unknown blocked-service %q", s) - continue + continue + } + + bsvcs = append(bsvcs, s) } - bsvcs = append(bsvcs, s) + d.BlockedServices.IDs = bsvcs } - d.BlockedServices = bsvcs if blockFilters != nil { err = d.initFiltering(nil, blockFilters) diff --git a/internal/filtering/http.go b/internal/filtering/http.go index dbfe6889..41964965 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -416,12 +416,12 @@ type checkHostResp struct { func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { host := r.URL.Query().Get("name") - setts := d.GetConfig() + setts := d.Settings() setts.FilteringEnabled = true setts.ProtectionEnabled = true - d.ApplyBlockedServices(&setts, nil) - result, err := d.CheckHost(host, dns.TypeA, &setts) + d.ApplyBlockedServices(setts) + result, err := d.CheckHost(host, dns.TypeA, setts) if err != nil { aghhttp.Error( r, diff --git a/internal/home/config.go b/internal/home/config.go index 8d9fa422..2eb0aff5 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/fastip" "github.com/AdguardTeam/golibs/errors" @@ -316,6 +317,11 @@ var config = &configuration{ Yandex: true, YouTube: true, }, + + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: []string{}, + }, }, UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, UsePrivateRDNS: true, diff --git a/internal/home/dns.go b/internal/home/dns.go index 9c3b03cc..48b332f2 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -390,7 +390,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering // pref is a prefix for logging messages around the scope. const pref = "applying filters" - Context.filters.ApplyBlockedServices(setts, nil) + Context.filters.ApplyBlockedServices(setts) log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID) @@ -418,7 +418,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering if svcs == nil { svcs = []string{} } - Context.filters.ApplyBlockedServices(setts, svcs) + Context.filters.ApplyBlockedServicesList(setts, svcs) log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) } diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index e429eb41..2326bdb7 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -22,7 +22,7 @@ import ( ) // currentSchemaVersion is the current schema version. -const currentSchemaVersion = 20 +const currentSchemaVersion = 21 // These aliases are provided for convenience. type ( @@ -94,6 +94,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) { upgradeSchema17to18, upgradeSchema18to19, upgradeSchema19to20, + upgradeSchema20to21, } n := 0 @@ -1128,6 +1129,56 @@ func upgradeSchema19to20(diskConf yobj) (err error) { return nil } +// upgradeSchema20to21 performs the following changes: +// +// # BEFORE: +// 'dns': +// 'blocked_services': +// - 'svc_name' +// +// # AFTER: +// 'dns': +// 'blocked_services': +// 'ids': +// - 'svc_name' +// 'schedule': +// 'time_zone': 'Local' +func upgradeSchema20to21(diskConf yobj) (err error) { + log.Printf("Upgrade yaml: 20 to 21") + diskConf["schema_version"] = 21 + + const field = "blocked_services" + + dnsVal, ok := diskConf["dns"] + if !ok { + return nil + } + + dns, ok := dnsVal.(yobj) + if !ok { + return fmt.Errorf("unexpected type of dns: %T", dnsVal) + } + + blockedVal, ok := dns[field] + if !ok { + return nil + } + + services, ok := blockedVal.(yarr) + if !ok { + return fmt.Errorf("unexpected type of blocked: %T", blockedVal) + } + + dns[field] = yobj{ + "ids": services, + "schedule": yobj{ + "time_zone": "Local", + }, + } + + return nil +} + // TODO(a.garipov): Replace with log.Output when we port it to our logging // package. func funcName() string { diff --git a/internal/home/upgrade_test.go b/internal/home/upgrade_test.go index 11820be0..aabf514f 100644 --- a/internal/home/upgrade_test.go +++ b/internal/home/upgrade_test.go @@ -1140,3 +1140,46 @@ func TestUpgradeSchema19to20(t *testing.T) { assert.Equal(t, 24*time.Hour, ivlVal.Duration) }) } + +func TestUpgradeSchema20to21(t *testing.T) { + const newSchemaVer = 21 + + testCases := []struct { + in yobj + want yobj + name string + }{{ + name: "nothing", + in: yobj{}, + want: yobj{ + "schema_version": newSchemaVer, + }, + }, { + name: "no_clients", + in: yobj{ + "dns": yobj{ + "blocked_services": yarr{"ok"}, + }, + }, + want: yobj{ + "dns": yobj{ + "blocked_services": yobj{ + "ids": yarr{"ok"}, + "schedule": yobj{ + "time_zone": "Local", + }, + }, + }, + "schema_version": newSchemaVer, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := upgradeSchema20to21(tc.in) + require.NoError(t, err) + + assert.Equal(t, tc.want, tc.in) + }) + } +} diff --git a/internal/schedule/schedule.go b/internal/schedule/schedule.go new file mode 100644 index 00000000..ba3757f9 --- /dev/null +++ b/internal/schedule/schedule.go @@ -0,0 +1,220 @@ +// Package schedule provides types for scheduling. +package schedule + +import ( + "fmt" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/timeutil" + "gopkg.in/yaml.v3" +) + +// Weekly is a schedule for one week. Each day of the week has one range with +// a beginning and an end. +type Weekly struct { + // location is used to calculate the offsets of the day ranges. + location *time.Location + + // days are the day ranges of this schedule. The indexes of this array are + // the [time.Weekday] values. + days [7]dayRange +} + +// EmptyWeekly creates empty weekly schedule with local time zone. +func EmptyWeekly() (w *Weekly) { + return &Weekly{ + location: time.Local, + } +} + +// Contains returns true if t is within the corresponding day range of the +// schedule in the schedule's time zone. +func (w *Weekly) Contains(t time.Time) (ok bool) { + t = t.In(w.location) + wd := t.Weekday() + dr := w.days[wd] + + // Calculate the offset of the day range. + // + // NOTE: Do not use [time.Truncate] since it requires UTC time zone. + y, m, d := t.Date() + day := time.Date(y, m, d, 0, 0, 0, 0, w.location) + offset := t.Sub(day) + + return dr.contains(offset) +} + +// type check +var _ yaml.Unmarshaler = (*Weekly)(nil) + +// UnmarshalYAML implements the [yaml.Unmarshaler] interface for *Weekly. +func (w *Weekly) UnmarshalYAML(value *yaml.Node) (err error) { + conf := &weeklyConfig{} + + err = value.Decode(conf) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + weekly := Weekly{} + + weekly.location, err = time.LoadLocation(conf.TimeZone) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + days := []dayConfig{ + time.Sunday: conf.Sunday, + time.Monday: conf.Monday, + time.Tuesday: conf.Tuesday, + time.Wednesday: conf.Wednesday, + time.Thursday: conf.Thursday, + time.Friday: conf.Friday, + time.Saturday: conf.Saturday, + } + for i, d := range days { + r := dayRange{ + start: d.Start.Duration, + end: d.End.Duration, + } + + err = w.validate(r) + if err != nil { + return fmt.Errorf("weekday %s: %w", time.Weekday(i), err) + } + + weekly.days[i] = r + } + + *w = weekly + + return nil +} + +// weeklyConfig is the YAML configuration structure of Weekly. +type weeklyConfig struct { + // TimeZone is the local time zone. + TimeZone string `yaml:"time_zone"` + + // Days of the week. + + Sunday dayConfig `yaml:"sun,omitempty"` + Monday dayConfig `yaml:"mon,omitempty"` + Tuesday dayConfig `yaml:"tue,omitempty"` + Wednesday dayConfig `yaml:"wed,omitempty"` + Thursday dayConfig `yaml:"thu,omitempty"` + Friday dayConfig `yaml:"fri,omitempty"` + Saturday dayConfig `yaml:"sat,omitempty"` +} + +// dayConfig is the YAML configuration structure of dayRange. +type dayConfig struct { + Start timeutil.Duration `yaml:"start"` + End timeutil.Duration `yaml:"end"` +} + +// maxDayRange is the maximum value for day range end. +const maxDayRange = 24 * time.Hour + +// validate returns the day range rounding errors, if any. +func (w *Weekly) validate(r dayRange) (err error) { + defer func() { err = errors.Annotate(err, "bad day range: %w") }() + + err = r.validate() + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + start := r.start.Truncate(time.Minute) + end := r.end.Truncate(time.Minute) + + switch { + case start != r.start: + return fmt.Errorf("start %s isn't rounded to minutes", r.start) + case end != r.end: + return fmt.Errorf("end %s isn't rounded to minutes", r.end) + default: + return nil + } +} + +// type check +var _ yaml.Marshaler = (*Weekly)(nil) + +// MarshalYAML implements the [yaml.Marshaler] interface for *Weekly. +func (w *Weekly) MarshalYAML() (v any, err error) { + return weeklyConfig{ + TimeZone: w.location.String(), + Sunday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Sunday].start}, + End: timeutil.Duration{Duration: w.days[time.Sunday].end}, + }, + Monday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Monday].start}, + End: timeutil.Duration{Duration: w.days[time.Monday].end}, + }, + Tuesday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Tuesday].start}, + End: timeutil.Duration{Duration: w.days[time.Tuesday].end}, + }, + Wednesday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Wednesday].start}, + End: timeutil.Duration{Duration: w.days[time.Wednesday].end}, + }, + Thursday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Thursday].start}, + End: timeutil.Duration{Duration: w.days[time.Thursday].end}, + }, + Friday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Friday].start}, + End: timeutil.Duration{Duration: w.days[time.Friday].end}, + }, + Saturday: dayConfig{ + Start: timeutil.Duration{Duration: w.days[time.Saturday].start}, + End: timeutil.Duration{Duration: w.days[time.Saturday].end}, + }, + }, nil +} + +// dayRange represents a single interval within a day. The interval begins at +// start and ends before end. That is, it contains a time point T if start <= +// T < end. +type dayRange struct { + // start is an offset from the beginning of the day. It must be greater + // than or equal to zero and less than 24h. + start time.Duration + + // end is an offset from the beginning of the day. It must be greater than + // or equal to zero and less than or equal to 24h. + end time.Duration +} + +// validate returns the day range validation errors, if any. +func (r dayRange) validate() (err error) { + switch { + case r == dayRange{}: + return nil + case r.start < 0: + return fmt.Errorf("start %s is negative", r.start) + case r.end < 0: + return fmt.Errorf("end %s is negative", r.end) + case r.start >= r.end: + return fmt.Errorf("start %s is greater or equal to end %s", r.start, r.end) + case r.start >= maxDayRange: + return fmt.Errorf("start %s is greater or equal to %s", r.start, maxDayRange) + case r.end > maxDayRange: + return fmt.Errorf("end %s is greater than %s", r.end, maxDayRange) + default: + return nil + } +} + +// contains returns true if start <= offset < end, where offset is the time +// duration from the beginning of the day. +func (r *dayRange) contains(offset time.Duration) (ok bool) { + return r.start <= offset && offset < r.end +} diff --git a/internal/schedule/schedule_internal_test.go b/internal/schedule/schedule_internal_test.go new file mode 100644 index 00000000..f500524e --- /dev/null +++ b/internal/schedule/schedule_internal_test.go @@ -0,0 +1,371 @@ +package schedule + +import ( + "testing" + "time" + + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/timeutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestWeekly_Contains(t *testing.T) { + baseTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + otherTime := baseTime.Add(1 * timeutil.Day) + + // NOTE: In the Etc area the sign of the offsets is flipped. So, Etc/GMT-3 + // is actually UTC+03:00. + otherTZ := time.FixedZone("Etc/GMT-3", 3*60*60) + + // baseSchedule, 12:00 to 14:00. + baseSchedule := &Weekly{ + days: [7]dayRange{ + time.Friday: {start: 12 * time.Hour, end: 14 * time.Hour}, + }, + location: time.UTC, + } + + // allDaySchedule, 00:00 to 24:00. + allDaySchedule := &Weekly{ + days: [7]dayRange{ + time.Friday: {start: 0, end: 24 * time.Hour}, + }, + location: time.UTC, + } + + // oneMinSchedule, 00:00 to 00:01. + oneMinSchedule := &Weekly{ + days: [7]dayRange{ + time.Friday: {start: 0, end: 1 * time.Minute}, + }, + location: time.UTC, + } + + testCases := []struct { + schedule *Weekly + assert assert.BoolAssertionFunc + t time.Time + name string + }{{ + schedule: EmptyWeekly(), + assert: assert.False, + t: baseTime, + name: "empty", + }, { + schedule: allDaySchedule, + assert: assert.True, + t: baseTime, + name: "same_day_all_day", + }, { + schedule: baseSchedule, + assert: assert.True, + t: baseTime.Add(13 * time.Hour), + name: "same_day_inside", + }, { + schedule: baseSchedule, + assert: assert.False, + t: baseTime.Add(11 * time.Hour), + name: "same_day_outside", + }, { + schedule: allDaySchedule, + assert: assert.True, + t: baseTime.Add(24*time.Hour - time.Second), + name: "same_day_last_second", + }, { + schedule: allDaySchedule, + assert: assert.False, + t: otherTime, + name: "other_day_all_day", + }, { + schedule: baseSchedule, + assert: assert.False, + t: otherTime.Add(13 * time.Hour), + name: "other_day_inside", + }, { + schedule: baseSchedule, + assert: assert.False, + t: otherTime.Add(11 * time.Hour), + name: "other_day_outside", + }, { + schedule: baseSchedule, + assert: assert.True, + t: baseTime.Add(13 * time.Hour).In(otherTZ), + name: "same_day_inside_other_tz", + }, { + schedule: baseSchedule, + assert: assert.False, + t: baseTime.Add(11 * time.Hour).In(otherTZ), + name: "same_day_outside_other_tz", + }, { + schedule: oneMinSchedule, + assert: assert.True, + t: baseTime, + name: "one_minute_beginning", + }, { + schedule: oneMinSchedule, + assert: assert.True, + t: baseTime.Add(1*time.Minute - 1), + name: "one_minute_end", + }, { + schedule: oneMinSchedule, + assert: assert.False, + t: baseTime.Add(1 * time.Minute), + name: "one_minute_past_end", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.schedule.Contains(tc.t)) + }) + } +} + +const brusselsSunday = ` +sun: + start: 12h + end: 14h +time_zone: Europe/Brussels +` + +func TestWeekly_UnmarshalYAML(t *testing.T) { + const ( + sameTime = ` +sun: + start: 9h + end: 9h +` + negativeStart = ` +sun: + start: -1h + end: 1h +` + badTZ = ` +time_zone: "bad_timezone" +` + badYAML = ` +yaml: "bad" +yaml: "bad" +` + ) + + brusseltsTZ, err := time.LoadLocation("Europe/Brussels") + require.NoError(t, err) + + brusselsWeekly := &Weekly{ + days: [7]dayRange{{ + start: time.Hour * 12, + end: time.Hour * 14, + }}, + location: brusseltsTZ, + } + + testCases := []struct { + name string + wantErrMsg string + data []byte + want *Weekly + }{{ + name: "empty", + wantErrMsg: "", + data: []byte(""), + want: &Weekly{}, + }, { + name: "null", + wantErrMsg: "", + data: []byte("null"), + want: &Weekly{}, + }, { + name: "brussels_sunday", + wantErrMsg: "", + data: []byte(brusselsSunday), + want: brusselsWeekly, + }, { + name: "start_equal_end", + wantErrMsg: "weekday Sunday: bad day range: start 9h0m0s is greater or equal to end 9h0m0s", + data: []byte(sameTime), + want: &Weekly{}, + }, { + name: "start_negative", + wantErrMsg: "weekday Sunday: bad day range: start -1h0m0s is negative", + data: []byte(negativeStart), + want: &Weekly{}, + }, { + name: "bad_time_zone", + wantErrMsg: "unknown time zone bad_timezone", + data: []byte(badTZ), + want: &Weekly{}, + }, { + name: "bad_yaml", + wantErrMsg: "yaml: unmarshal errors:\n line 3: mapping key \"yaml\" already defined at line 2", + data: []byte(badYAML), + want: &Weekly{}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := &Weekly{} + err = yaml.Unmarshal(tc.data, w) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.want, w) + }) + } +} + +func TestWeekly_MarshalYAML(t *testing.T) { + brusselsTZ, err := time.LoadLocation("Europe/Brussels") + require.NoError(t, err) + + brusselsWeekly := &Weekly{ + days: [7]dayRange{time.Sunday: { + start: time.Hour * 12, + end: time.Hour * 14, + }}, + location: brusselsTZ, + } + + testCases := []struct { + name string + data []byte + want *Weekly + }{{ + name: "empty", + data: []byte(""), + want: &Weekly{}, + }, { + name: "null", + data: []byte("null"), + want: &Weekly{}, + }, { + name: "brussels_sunday", + data: []byte(brusselsSunday), + want: brusselsWeekly, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var data []byte + data, err = yaml.Marshal(brusselsWeekly) + require.NoError(t, err) + + w := &Weekly{} + err = yaml.Unmarshal(data, w) + require.NoError(t, err) + + assert.Equal(t, brusselsWeekly, w) + }) + } +} + +func TestWeekly_Validate(t *testing.T) { + testCases := []struct { + name string + in dayRange + wantErrMsg string + }{{ + name: "empty", + wantErrMsg: "", + in: dayRange{}, + }, { + name: "start_seconds", + wantErrMsg: "bad day range: start 1s isn't rounded to minutes", + in: dayRange{ + start: time.Second, + end: time.Hour, + }, + }, { + name: "end_seconds", + wantErrMsg: "bad day range: end 1s isn't rounded to minutes", + in: dayRange{ + start: 0, + end: time.Second, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := &Weekly{} + err := w.validate(tc.in) + + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} + +func TestDayRange_Validate(t *testing.T) { + testCases := []struct { + name string + in dayRange + wantErrMsg string + }{{ + name: "empty", + wantErrMsg: "", + in: dayRange{}, + }, { + name: "valid", + wantErrMsg: "", + in: dayRange{ + start: time.Hour, + end: time.Hour * 2, + }, + }, { + name: "valid_end_max", + wantErrMsg: "", + in: dayRange{ + start: 0, + end: time.Hour * 24, + }, + }, { + name: "start_negative", + wantErrMsg: "start -1h0m0s is negative", + in: dayRange{ + start: time.Hour * -1, + end: time.Hour * 2, + }, + }, { + name: "end_negative", + wantErrMsg: "end -1h0m0s is negative", + in: dayRange{ + start: 0, + end: time.Hour * -1, + }, + }, { + name: "start_equal_end", + wantErrMsg: "start 1h0m0s is greater or equal to end 1h0m0s", + in: dayRange{ + start: time.Hour, + end: time.Hour, + }, + }, { + name: "start_greater_end", + wantErrMsg: "start 2h0m0s is greater or equal to end 1h0m0s", + in: dayRange{ + start: time.Hour * 2, + end: time.Hour, + }, + }, { + name: "start_equal_max", + wantErrMsg: "start 24h0m0s is greater or equal to 24h0m0s", + in: dayRange{ + start: time.Hour * 24, + end: time.Hour * 48, + }, + }, { + name: "end_greater_max", + wantErrMsg: "end 48h0m0s is greater than 24h0m0s", + in: dayRange{ + start: 0, + end: time.Hour * 48, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.validate() + + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +}