diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 10ca71bb..e5b67d10 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -492,9 +492,9 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti s.serverLock.RLock() defer s.serverLock.RUnlock() - disabledUntil = s.dnsFilter.ProtectionDisabledUntil + enabled, disabledUntil = s.dnsFilter.ProtectionStatus() if disabledUntil == nil { - return s.dnsFilter.ProtectionEnabled, nil + return enabled, nil } if time.Now().Before(*disabledUntil) { @@ -526,8 +526,7 @@ func (s *Server) enableProtectionAfterPause() { s.serverLock.Lock() defer s.serverLock.Unlock() - s.dnsFilter.ProtectionEnabled = true - s.dnsFilter.ProtectionDisabledUntil = nil + s.dnsFilter.SetProtectionStatus(true, nil) log.Info("dns: protection is restarted after pause") } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 1af4c20a..6782a5c3 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -544,11 +544,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { // dnsFilter can be nil during application update. if s.dnsFilter != nil { - err = validateBlockingMode( - s.dnsFilter.BlockingMode, - s.dnsFilter.BlockingIPv4, - s.dnsFilter.BlockingIPv6, - ) + mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode() + err = validateBlockingMode(mode, bIPv4, bIPv6) if err != nil { return fmt.Errorf("checking blocking mode: %w", err) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index a438aa91..8417a0df 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -105,10 +105,6 @@ func createTestServer( }) require.NoError(t, err) - if s.dnsFilter.BlockingMode == "" { - s.dnsFilter.BlockingMode = filtering.BlockingModeDefault - } - err = s.Prepare(&forwardConf) require.NoError(t, err) @@ -178,7 +174,9 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) var keyPem []byte _, certPem, keyPem = createServerTLSConfig(t) - s = createTestServer(t, &filtering.Config{}, ServerConfig{ + s = createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ @@ -351,9 +349,8 @@ func TestServer_timeout(t *testing.T) { }, } - s, err := NewServer(DNSCreateParams{DNSFilter: &filtering.DNSFilter{}}) + s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) require.NoError(t, err) - s.dnsFilter.BlockingMode = filtering.BlockingModeDefault err = s.Prepare(srvConf) require.NoError(t, err) @@ -362,10 +359,9 @@ func TestServer_timeout(t *testing.T) { }) t.Run("default", func(t *testing.T) { - s, err := NewServer(DNSCreateParams{DNSFilter: &filtering.DNSFilter{}}) + s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) require.NoError(t, err) - s.dnsFilter.BlockingMode = filtering.BlockingModeDefault s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{ Enabled: false, } @@ -377,7 +373,9 @@ func TestServer_timeout(t *testing.T) { } func TestServerWithProtectionDisabled(t *testing.T) { - s := createTestServer(t, &filtering.Config{}, ServerConfig{ + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ @@ -490,6 +488,7 @@ func TestSafeSearch(t *testing.T) { } filterConf := &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, ProtectionEnabled: true, SafeSearchConf: safeSearchConf, SafeSearchCacheSize: 1000, @@ -564,7 +563,9 @@ func TestSafeSearch(t *testing.T) { } func TestInvalidRequest(t *testing.T) { - s := createTestServer(t, &filtering.Config{}, ServerConfig{ + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ @@ -631,7 +632,9 @@ func TestServerCustomClientUpstream(t *testing.T) { }, }, } - s := createTestServer(t, &filtering.Config{}, forwardConf, nil) + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, forwardConf, nil) s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { return aghalg.Coalesce( @@ -674,7 +677,9 @@ var testIPv4 = map[string][]net.IP{ } func TestBlockCNAMEProtectionEnabled(t *testing.T) { - s := createTestServer(t, &filtering.Config{}, ServerConfig{ + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ @@ -789,7 +794,9 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { }, }, } - s := createTestServer(t, &filtering.Config{}, forwardConf, nil) + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.Upstream{ CName: testCNAMEs, @@ -901,8 +908,10 @@ func TestBlockedCustomIP(t *testing.T) { err = s.Prepare(conf) assert.Error(t, err) - s.dnsFilter.BlockingIPv4 = netip.AddrFrom4([4]byte{0, 0, 0, 1}) - s.dnsFilter.BlockingIPv6 = netip.MustParseAddr("::1") + s.dnsFilter.SetBlockingMode( + filtering.BlockingModeCustomIP, + netip.AddrFrom4([4]byte{0, 0, 0, 1}), + netip.MustParseAddr("::1")) err = s.Prepare(conf) require.NoError(t, err) @@ -980,6 +989,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { ans4, _ := aghtest.HostToIPs(hostname) filterConf := &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, ProtectionEnabled: true, SafeBrowsingEnabled: true, SafeBrowsingChecker: sbChecker, diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_test.go index 1566e2ff..79aecdef 100644 --- a/internal/dnsforward/dnsrewrite_test.go +++ b/internal/dnsforward/dnsrewrite_test.go @@ -34,9 +34,14 @@ func TestServer_FilterDNSRewrite(t *testing.T) { } // Helper functions and entities. - srv := &Server{ - dnsFilter: &filtering.DNSFilter{}, - } + srv := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + Config: Config{ + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + }, nil) + makeQ := func(qtype rules.RRType) (req *dns.Msg) { return &dns.Msg{ Question: []dns.Question{{ diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 05838877..95e7f407 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -114,9 +114,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { upstreamFile := s.conf.UpstreamDNSFileName bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS) fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS) - blockingMode := s.dnsFilter.BlockingMode - blockingIPv4 := s.dnsFilter.BlockingIPv4 - blockingIPv6 := s.dnsFilter.BlockingIPv6 + blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode() ratelimit := s.conf.Ratelimit customIP := s.conf.EDNSClientSubnet.CustomIP @@ -320,11 +318,11 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { defer s.serverLock.Unlock() if dc.BlockingMode != nil { - s.dnsFilter.BlockingMode = *dc.BlockingMode - if *dc.BlockingMode == filtering.BlockingModeCustomIP { - s.dnsFilter.BlockingIPv4 = dc.BlockingIPv4 - s.dnsFilter.BlockingIPv6 = dc.BlockingIPv6 - } + s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6) + } + + if dc.ProtectionEnabled != nil { + s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled) } if dc.UpstreamMode != nil { @@ -336,7 +334,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP } - setIfNotNil(&s.dnsFilter.ProtectionEnabled, dc.ProtectionEnabled) setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled) setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6) @@ -690,8 +687,8 @@ func (s *Server) parseUpstreamLine( } // dnsFilter can be nil during application update. - if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil { - recs := s.dnsFilter.EtcHosts.MatchName(extractUpstreamHost(upstreamAddr)) + if s.dnsFilter != nil { + recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr)) for _, rec := range recs { opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice()) } @@ -832,8 +829,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) { s.serverLock.Lock() defer s.serverLock.Unlock() - s.dnsFilter.ProtectionEnabled = protectionReq.Enabled - s.dnsFilter.ProtectionDisabledUntil = disabledUntil + s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil) }() s.conf.ConfigModified() diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 2721a9b4..8eda01c4 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -157,7 +157,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { s.sysResolvers = &fakeSystemResolvers{} defaultConf := s.conf - defaultFilterConf := filterConf err := s.Start() assert.NoError(t, err) @@ -248,7 +247,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { - s.dnsFilter.Config = *defaultFilterConf + s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{}) s.conf = defaultConf s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{} }) @@ -500,7 +499,8 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { require.NoError(t, err) srv := createTestServer(t, &filtering.Config{ - EtcHosts: hc, + BlockingMode: filtering.BlockingModeDefault, + EtcHosts: hc, }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index b8ed6a10..6aa47103 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -50,7 +50,8 @@ func (s *Server) genDNSFilterMessage( req := dctx.Req qt := req.Question[0].Qtype if qt != dns.TypeA && qt != dns.TypeAAAA { - if s.dnsFilter.BlockingMode == filtering.BlockingModeNullIP { + m, _, _ := s.dnsFilter.BlockingMode() + if m == filtering.BlockingModeNullIP { return s.makeResponse(req) } @@ -59,9 +60,9 @@ func (s *Server) genDNSFilterMessage( switch res.Reason { case filtering.FilteredSafeBrowsing: - return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost, dctx) + return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx) case filtering.FilteredParental: - return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost, dctx) + return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx) case filtering.FilteredSafeSearch: // If Safe Search generated the necessary IP addresses, use them. // Otherwise, if there were no errors, there are no addresses for the @@ -75,21 +76,9 @@ func (s *Server) genDNSFilterMessage( // genForBlockingMode generates a filtered response to req based on the server's // blocking mode. func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) { - qt := req.Question[0].Qtype - switch m := s.dnsFilter.BlockingMode; m { + switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode { case filtering.BlockingModeCustomIP: - switch qt { - case dns.TypeA: - return s.genARecord(req, s.dnsFilter.BlockingIPv4) - case dns.TypeAAAA: - return s.genAAAARecord(req, s.dnsFilter.BlockingIPv6) - default: - // Generally shouldn't happen, since the types are checked in - // genDNSFilterMessage. - log.Error("dns: invalid msg type %s for blocking mode %s", dns.Type(qt), m) - - return s.makeResponse(req) - } + return s.makeResponseCustomIP(req, bIPv4, bIPv6) case filtering.BlockingModeDefault: if len(ips) > 0 { return s.genResponseWithIPs(req, ips) @@ -103,7 +92,28 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M case filtering.BlockingModeREFUSED: return s.makeResponseREFUSED(req) default: - log.Error("dns: invalid blocking mode %q", s.dnsFilter.BlockingMode) + log.Error("dns: invalid blocking mode %q", mode) + + return s.makeResponse(req) + } +} + +// makeResponseCustomIP generates a DNS response message for Custom IP blocking +// mode with the provided IP addresses and an appropriate resource record type. +func (s *Server) makeResponseCustomIP( + req *dns.Msg, + bIPv4 netip.Addr, + bIPv6 netip.Addr, +) (resp *dns.Msg) { + switch qt := req.Question[0].Qtype; qt { + case dns.TypeA: + return s.genARecord(req, bIPv4) + case dns.TypeAAAA: + return s.genAAAARecord(req, bIPv6) + default: + // Generally shouldn't happen, since the types are checked in + // genDNSFilterMessage. + log.Error("dns: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) return s.makeResponse(req) } @@ -132,7 +142,7 @@ func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) { return dns.RR_Header{ Name: req.Question[0].Name, Rrtype: rrType, - Ttl: s.dnsFilter.BlockedResponseTTL, + Ttl: s.dnsFilter.BlockedResponseTTL(), Class: dns.ClassINET, } } @@ -352,7 +362,7 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { Hdr: dns.RR_Header{ Name: zone, Rrtype: dns.TypeSOA, - Ttl: s.dnsFilter.BlockedResponseTTL, + Ttl: s.dnsFilter.BlockedResponseTTL(), Class: dns.ClassINET, }, Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 09cef976..4780c856 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -607,7 +607,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { Rrtype: dns.TypePTR, // TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See // https://github.com/AdguardTeam/AdGuardHome/issues/3932. - Ttl: s.dnsFilter.BlockedResponseTTL, + Ttl: s.dnsFilter.BlockedResponseTTL(), Class: dns.ClassINET, }, Ptr: dns.Fqdn(strings.Join([]string{host, s.localDomainSuffix}, ".")), diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 8c100de9..168a97a1 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -83,7 +83,9 @@ func TestServer_ProcessInitial(t *testing.T) { }, } - s := createTestServer(t, &filtering.Config{}, c, nil) + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, c, nil) var gotAddr netip.Addr s.addrProc = &aghtest.AddressProcessor{ @@ -180,7 +182,9 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) { }, } - s := createTestServer(t, &filtering.Config{}, c, nil) + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, c, nil) resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns) dctx := &dnsContext{ @@ -338,11 +342,23 @@ func TestServer_ProcessDDRQuery(t *testing.T) { } } +// createTestDNSFilter returns the minimum valid DNSFilter. +func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) { + t.Helper() + + f, err := filtering.New(&filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, []filtering.Filter{}) + require.NoError(t, err) + + return f +} + func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) { t.Helper() s = &Server{ - dnsFilter: &filtering.DNSFilter{}, + dnsFilter: createTestDNSFilter(t), dnsProxy: &proxy.Proxy{ Config: proxy.Config{}, }, @@ -467,7 +483,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := &Server{ - dnsFilter: &filtering.DNSFilter{}, + dnsFilter: createTestDNSFilter(t), dhcpServer: dhcp, localDomainSuffix: localDomainSuffix, } @@ -602,7 +618,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { } s := &Server{ - dnsFilter: &filtering.DNSFilter{}, + dnsFilter: createTestDNSFilter(t), dhcpServer: testDHCP, localDomainSuffix: tc.suffix, } @@ -673,7 +689,9 @@ func TestServer_ProcessRestrictLocal(t *testing.T) { ), nil }) - s := createTestServer(t, &filtering.Config{}, ServerConfig{ + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, // TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true. @@ -749,7 +767,9 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { s := createTestServer( t, - &filtering.Config{}, + &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_test.go index 2804c52a..83645b32 100644 --- a/internal/dnsforward/svcbmsg_test.go +++ b/internal/dnsforward/svcbmsg_test.go @@ -13,9 +13,13 @@ import ( func TestGenAnswerHTTPS_andSVCB(t *testing.T) { // Preconditions. - s := &Server{ - dnsFilter: &filtering.DNSFilter{}, - } + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + Config: Config{ + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + }, nil) req := &dns.Msg{ Question: []dns.Question{{ diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index c08ab057..0debe8d8 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -93,7 +93,7 @@ func (s *Server) prepareUpstreamConfig( } // dnsFilter can be nil during application update. - if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil { + if s.dnsFilter != nil { err = s.replaceUpstreamsWithHosts(uc, opts) if err != nil { return nil, fmt.Errorf("resolving upstreams with hosts: %w", err) @@ -157,7 +157,7 @@ func (s *Server) resolveUpstreamsWithHosts( withIPs, ok := resolved[host] if !ok { - recs := s.dnsFilter.EtcHosts.MatchName(host) + recs := s.dnsFilter.EtcHostsRecords(host) if len(recs) == 0 { resolved[host] = nil diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 2dba12da..cd388853 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -83,12 +83,12 @@ func (s *BlockedServices) Validate() (err error) { // ApplyBlockedServices - set blocked services settings for this DNS request func (d *DNSFilter) ApplyBlockedServices(setts *Settings) { - d.confLock.RLock() - defer d.confLock.RUnlock() + d.confMu.RLock() + defer d.confMu.RUnlock() setts.ServicesRules = []ServiceEntry{} - bsvc := d.BlockedServices + bsvc := d.conf.BlockedServices // TODO(s.chzhen): Use startTime from [dnsforward.dnsContext]. if !bsvc.Schedule.Contains(time.Now()) { @@ -130,9 +130,13 @@ func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Requ // // Deprecated: Use handleBlockedServicesGet. func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { - d.confLock.RLock() - list := d.Config.BlockedServices.IDs - d.confLock.RUnlock() + var list []string + func() { + d.confMu.Lock() + defer d.confMu.Unlock() + + list = d.conf.BlockedServices.IDs + }() aghhttp.WriteJSONResponseOK(w, r, list) } @@ -150,13 +154,15 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ return } - d.confLock.Lock() - d.Config.BlockedServices.IDs = list - d.confLock.Unlock() + func() { + d.confMu.Lock() + defer d.confMu.Unlock() - log.Debug("Updated blocked services list: %d", len(list)) + d.conf.BlockedServices.IDs = list + log.Debug("Updated blocked services list: %d", len(list)) + }() - d.Config.ConfigModified() + d.conf.ConfigModified() } // handleBlockedServicesGet is the handler for the GET @@ -164,10 +170,10 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Request) { var bsvc *BlockedServices func() { - d.confLock.RLock() - defer d.confLock.RUnlock() + d.confMu.RLock() + defer d.confMu.RUnlock() - bsvc = d.Config.BlockedServices.Clone() + bsvc = d.conf.BlockedServices.Clone() }() aghhttp.WriteJSONResponseOK(w, r, bsvc) @@ -196,13 +202,13 @@ func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.R } func() { - d.confLock.Lock() - defer d.confLock.Unlock() + d.confMu.Lock() + defer d.confMu.Unlock() - d.Config.BlockedServices = bsvc + d.conf.BlockedServices = bsvc }() log.Debug("updated blocked services schedule: %d", len(bsvc.IDs)) - d.Config.ConfigModified() + d.conf.ConfigModified() } diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 0d476802..329b6745 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -91,12 +91,12 @@ func (d *DNSFilter) filterSetProperties( newList FilterYAML, isAllowlist bool, ) (shouldRestart bool, err error) { - d.filtersMu.Lock() - defer d.filtersMu.Unlock() + d.conf.filtersMu.Lock() + defer d.conf.filtersMu.Unlock() - filters := d.Filters + filters := d.conf.Filters if isAllowlist { - filters = d.WhitelistFilters + filters = d.conf.WhitelistFilters } i := slices.IndexFunc(filters, func(flt FilterYAML) bool { return flt.URL == listURL }) @@ -162,8 +162,8 @@ func (d *DNSFilter) filterSetProperties( // filterExists returns true if a filter with the same url exists in d. It's // safe for concurrent use. func (d *DNSFilter) filterExists(url string) (ok bool) { - d.filtersMu.RLock() - defer d.filtersMu.RUnlock() + d.conf.filtersMu.RLock() + defer d.conf.filtersMu.RUnlock() r := d.filterExistsLocked(url) @@ -173,13 +173,13 @@ func (d *DNSFilter) filterExists(url string) (ok bool) { // filterExistsLocked returns true if d contains the filter with the same url. // d.filtersMu is expected to be locked. func (d *DNSFilter) filterExistsLocked(url string) (ok bool) { - for _, f := range d.Filters { + for _, f := range d.conf.Filters { if f.URL == url { return true } } - for _, f := range d.WhitelistFilters { + for _, f := range d.conf.WhitelistFilters { if f.URL == url { return true } @@ -194,8 +194,8 @@ func (d *DNSFilter) filterAdd(flt FilterYAML) (err error) { // Defer annotating to unlock sooner. defer func() { err = errors.Annotate(err, "adding filter: %w") }() - d.filtersMu.Lock() - defer d.filtersMu.Unlock() + d.conf.filtersMu.Lock() + defer d.conf.filtersMu.Unlock() // Check for duplicates. if d.filterExistsLocked(flt.URL) { @@ -203,9 +203,9 @@ func (d *DNSFilter) filterAdd(flt FilterYAML) (err error) { } if flt.white { - d.WhitelistFilters = append(d.WhitelistFilters, flt) + d.conf.WhitelistFilters = append(d.conf.WhitelistFilters, flt) } else { - d.Filters = append(d.Filters, flt) + d.conf.Filters = append(d.conf.Filters, flt) } return nil @@ -269,7 +269,7 @@ func (d *DNSFilter) periodicallyRefreshFilters() { ivl := 5 // use a dynamically increasing time interval for { isNetErr, ok := false, false - if d.FiltersUpdateIntervalHours != 0 { + if d.conf.FiltersUpdateIntervalHours != 0 { _, isNetErr, ok = d.tryRefreshFilters(true, true, false) if ok && !isNetErr { ivl = maxInterval @@ -307,8 +307,8 @@ func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, is func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) { now := time.Now() - d.filtersMu.RLock() - defer d.filtersMu.RUnlock() + d.conf.filtersMu.RLock() + defer d.conf.filtersMu.RUnlock() for i := range *filters { flt := &(*filters)[i] // otherwise we will be operating on a copy @@ -318,7 +318,7 @@ func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []Fi } if !force { - exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour) + exp := flt.LastUpdated.Add(time.Duration(d.conf.FiltersUpdateIntervalHours) * time.Hour) if now.Before(exp) { continue } @@ -364,8 +364,8 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, updateCount := 0 - d.filtersMu.Lock() - defer d.filtersMu.Unlock() + d.conf.filtersMu.Lock() + defer d.conf.filtersMu.Unlock() for i := range updateFilters { uf := &updateFilters[i] @@ -427,10 +427,10 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) { isNetErr := false if block { - updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force) + updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.conf.Filters, force) } if allow { - updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force) + updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.conf.WhitelistFilters, force) updNum += updNumAl lists = append(lists, listsAl...) @@ -451,7 +451,7 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) { continue } - p := uf.Path(d.DataDir) + p := uf.Path(d.conf.DataDir) err := os.Remove(p + ".old") if err != nil { log.Debug("filtering: removing old filter file %q: %s", p, err) @@ -468,7 +468,7 @@ func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) { filter.LastUpdated = time.Now() if !b { chErr := os.Chtimes( - filter.Path(d.DataDir), + filter.Path(d.conf.DataDir), filter.LastUpdated, filter.LastUpdated, ) @@ -491,7 +491,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) { // users. // // See https://github.com/AdguardTeam/AdGuardHome/issues/3198. - tmpFile, err := aghrenameio.NewPendingFile(flt.Path(d.DataDir), 0o644) + tmpFile, err := aghrenameio.NewPendingFile(flt.Path(d.conf.DataDir), 0o644) if err != nil { return false, err } @@ -532,7 +532,7 @@ func (d *DNSFilter) finalizeUpdate( return errors.WithDeferred(returned, file.Cleanup()) } - log.Info("filtering: saving contents of filter %d into %q", id, flt.Path(d.DataDir)) + log.Info("filtering: saving contents of filter %d into %q", id, flt.Path(d.conf.DataDir)) err = file.CloseReplace() if err != nil { @@ -572,7 +572,7 @@ func (d *DNSFilter) reader(fltURL string) (r io.ReadCloser, err error) { // readerFromURL returns an io.ReadCloser reading filtering-rule list data form // the filter's URL. func (d *DNSFilter) readerFromURL(fltURL string) (r io.ReadCloser, err error) { - resp, err := d.HTTPClient.Get(fltURL) + resp, err := d.conf.HTTPClient.Get(fltURL) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err @@ -587,7 +587,7 @@ func (d *DNSFilter) readerFromURL(fltURL string) (r io.ReadCloser, err error) { // loads filter contents from the file in dataDir func (d *DNSFilter) load(flt *FilterYAML) (err error) { - fileName := flt.Path(d.DataDir) + fileName := flt.Path(d.conf.DataDir) log.Debug("filtering: loading filter %d from %q", flt.ID, fileName) @@ -623,39 +623,39 @@ func (d *DNSFilter) load(flt *FilterYAML) (err error) { } func (d *DNSFilter) EnableFilters(async bool) { - d.filtersMu.RLock() - defer d.filtersMu.RUnlock() + d.conf.filtersMu.RLock() + defer d.conf.filtersMu.RUnlock() d.enableFiltersLocked(async) } func (d *DNSFilter) enableFiltersLocked(async bool) { - filters := make([]Filter, 1, len(d.Filters)+len(d.WhitelistFilters)+1) + filters := make([]Filter, 1, len(d.conf.Filters)+len(d.conf.WhitelistFilters)+1) filters[0] = Filter{ ID: CustomListID, - Data: []byte(strings.Join(d.UserRules, "\n")), + Data: []byte(strings.Join(d.conf.UserRules, "\n")), } - for _, filter := range d.Filters { + for _, filter := range d.conf.Filters { if !filter.Enabled { continue } filters = append(filters, Filter{ ID: filter.ID, - FilePath: filter.Path(d.DataDir), + FilePath: filter.Path(d.conf.DataDir), }) } var allowFilters []Filter - for _, filter := range d.WhitelistFilters { + for _, filter := range d.conf.WhitelistFilters { if !filter.Enabled { continue } allowFilters = append(allowFilters, Filter{ ID: filter.ID, - FilePath: filter.Path(d.DataDir), + FilePath: filter.Path(d.conf.DataDir), }) } @@ -664,5 +664,5 @@ func (d *DNSFilter) enableFiltersLocked(async bool) { log.Error("filtering: enabling filters: %s", err) } - d.SetEnabled(d.FilteringEnabled) + d.SetEnabled(d.conf.FilteringEnabled) } diff --git a/internal/filtering/filter_test.go b/internal/filtering/filter_test.go index e613f1b3..229b7a9b 100644 --- a/internal/filtering/filter_test.go +++ b/internal/filtering/filter_test.go @@ -69,9 +69,9 @@ func updateAndAssert( assert.Equal(t, wantRulesCount, f.RulesCount) - dir, err := os.ReadDir(filepath.Join(dnsFilter.DataDir, filterDir)) + dir, err := os.ReadDir(filepath.Join(dnsFilter.conf.DataDir, filterDir)) require.NoError(t, err) - require.FileExists(t, f.Path(dnsFilter.DataDir)) + require.FileExists(t, f.Path(dnsFilter.conf.DataDir)) assert.Len(t, dir, 1) diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 506db370..c16897c0 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -240,15 +240,6 @@ type DNSFilter struct { rulesStorageAllow *filterlist.RuleStorage filteringEngineAllow *urlfilter.DNSEngine - // Config contains filtering parameters. For direct access by library - // users, even a = assignment. - // - // TODO(d.kolyshev): Remove this embed. - Config - - // confLock protects Config. - confLock sync.RWMutex - safeSearch SafeSearch // safeBrowsingChecker is the safe browsing hash-prefix checker. @@ -259,6 +250,12 @@ type DNSFilter struct { engineLock sync.RWMutex + // confMu protects conf. + confMu *sync.RWMutex + + // conf contains filtering parameters. + conf *Config + // Channel for passing data to filters-initializer goroutine filtersInitializerChan chan filtersInitializerParams filtersInitializerLock sync.Mutex @@ -358,38 +355,38 @@ func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons // SetEnabled sets the status of the *DNSFilter. func (d *DNSFilter) SetEnabled(enabled bool) { - atomic.StoreUint32(&d.enabled, mathutil.BoolToNumber[uint32](enabled)) + atomic.StoreUint32(&d.conf.enabled, mathutil.BoolToNumber[uint32](enabled)) } // Settings returns filtering settings. func (d *DNSFilter) Settings() (s *Settings) { - d.confLock.RLock() - defer d.confLock.RUnlock() + d.confMu.RLock() + defer d.confMu.RUnlock() return &Settings{ - FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0, - SafeSearchEnabled: d.Config.SafeSearchConf.Enabled, - SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, - ParentalEnabled: d.Config.ParentalEnabled, + FilteringEnabled: atomic.LoadUint32(&d.conf.enabled) != 0, + SafeSearchEnabled: d.conf.SafeSearchConf.Enabled, + SafeBrowsingEnabled: d.conf.SafeBrowsingEnabled, + ParentalEnabled: d.conf.ParentalEnabled, } } // WriteDiskConfig - write configuration func (d *DNSFilter) WriteDiskConfig(c *Config) { func() { - d.confLock.Lock() - defer d.confLock.Unlock() + d.confMu.Lock() + defer d.confMu.Unlock() - *c = d.Config + *c = *d.conf c.Rewrites = cloneRewrites(c.Rewrites) }() - d.filtersMu.RLock() - defer d.filtersMu.RUnlock() + d.conf.filtersMu.RLock() + defer d.conf.filtersMu.RUnlock() - c.Filters = slices.Clone(d.Filters) - c.WhitelistFilters = slices.Clone(d.WhitelistFilters) - c.UserRules = slices.Clone(d.UserRules) + c.Filters = slices.Clone(d.conf.Filters) + c.WhitelistFilters = slices.Clone(d.conf.WhitelistFilters) + c.UserRules = slices.Clone(d.conf.UserRules) } // setFilters sets new filters, synchronously or asynchronously. When filters @@ -461,6 +458,77 @@ func (d *DNSFilter) reset() { } } +// ProtectionStatus returns the status of protection and time until it's +// disabled if so. +func (d *DNSFilter) ProtectionStatus() (status bool, disabledUntil *time.Time) { + d.confMu.RLock() + defer d.confMu.RUnlock() + + return d.conf.ProtectionEnabled, d.conf.ProtectionDisabledUntil +} + +// SetProtectionStatus updates the status of protection and time until it's +// disabled. +func (d *DNSFilter) SetProtectionStatus(status bool, disabledUntil *time.Time) { + d.confMu.Lock() + defer d.confMu.Unlock() + + d.conf.ProtectionEnabled = status + d.conf.ProtectionDisabledUntil = disabledUntil +} + +// SetProtectionEnabled updates the status of protection. +func (d *DNSFilter) SetProtectionEnabled(status bool) { + d.confMu.Lock() + defer d.confMu.Unlock() + + d.conf.ProtectionEnabled = status +} + +// EtcHostsRecords returns the hosts records for the hostname. +func (d *DNSFilter) EtcHostsRecords(hostname string) (recs []*hostsfile.Record) { + if d.conf.EtcHosts != nil { + return d.conf.EtcHosts.MatchName(hostname) + } + + return recs +} + +// SetBlockingMode sets blocking mode properties. +func (d *DNSFilter) SetBlockingMode(mode BlockingMode, bIPv4, bIPv6 netip.Addr) { + d.confMu.Lock() + defer d.confMu.Unlock() + + d.conf.BlockingMode = mode + if mode == BlockingModeCustomIP { + d.conf.BlockingIPv4 = bIPv4 + d.conf.BlockingIPv6 = bIPv6 + } +} + +// BlockingMode returns blocking mode properties. +func (d *DNSFilter) BlockingMode() (mode BlockingMode, bIPv4, bIPv6 netip.Addr) { + d.confMu.RLock() + defer d.confMu.RUnlock() + + return d.conf.BlockingMode, d.conf.BlockingIPv4, d.conf.BlockingIPv6 +} + +// BlockedResponseTTL returns TTL for blocked responses. +func (d *DNSFilter) BlockedResponseTTL() (ttl uint32) { + return d.conf.BlockedResponseTTL +} + +// SafeBrowsingBlockHost returns a host for safe browsing blocked responses. +func (d *DNSFilter) SafeBrowsingBlockHost() (host string) { + return d.conf.SafeBrowsingBlockHost +} + +// ParentalBlockHost returns a host for parental protection blocked responses. +func (d *DNSFilter) ParentalBlockHost() (host string) { + return d.conf.ParentalBlockHost +} + // ResultRule contains information about applied rules. type ResultRule struct { // Text is the text of the rule. @@ -560,14 +628,14 @@ func (d *DNSFilter) matchSysHosts( setts *Settings, ) (res Result, err error) { // TODO(e.burkov): Where else is this checked? - if !setts.FilteringEnabled || d.EtcHosts == nil { + if !setts.FilteringEnabled || d.conf.EtcHosts == nil { return res, nil } var recs []*hostsfile.Record switch qtype { case dns.TypeA, dns.TypeAAAA: - recs = d.EtcHosts.MatchName(host) + recs = d.conf.EtcHosts.MatchName(host) case dns.TypePTR: var ip net.IP ip, err = netutil.IPFromReversedAddr(host) @@ -578,7 +646,7 @@ func (d *DNSFilter) matchSysHosts( } addr, _ := netip.AddrFromSlice(ip) - recs = d.EtcHosts.MatchAddr(addr) + recs = d.conf.EtcHosts.MatchAddr(addr) default: log.Debug("filtering: unsupported query type %s", dns.Type(qtype)) } @@ -618,10 +686,10 @@ func (d *DNSFilter) matchSysHosts( // accordingly. If the found rewrite has a special value of "A" or "AAAA", the // result is an exception. func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { - d.confLock.RLock() - defer d.confLock.RUnlock() + d.confMu.RLock() + defer d.confMu.RUnlock() - rewrites, matched := findRewrites(d.Rewrites, host, qtype) + rewrites, matched := findRewrites(d.conf.Rewrites, host, qtype) if !matched { return Result{} } @@ -661,7 +729,7 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { cnames.Add(host) res.CanonName = host - rewrites, matched = findRewrites(d.Rewrites, host, qtype) + rewrites, matched = findRewrites(d.conf.Rewrites, host, qtype) } setRewriteResult(&res, host, rewrites, qtype) @@ -992,6 +1060,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { refreshLock: &sync.Mutex{}, safeBrowsingChecker: c.SafeBrowsingChecker, parentalControlChecker: c.ParentalControlChecker, + confMu: &sync.RWMutex{}, } d.safeSearch = c.SafeSearch @@ -1018,16 +1087,16 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { defer func() { err = errors.Annotate(err, "filtering: %w") }() - d.Config = *c - d.filtersMu = &sync.RWMutex{} + d.conf = c + d.conf.filtersMu = &sync.RWMutex{} err = d.prepareRewrites() if err != nil { return nil, fmt.Errorf("rewrites: preparing: %s", err) } - if d.BlockedServices != nil { - err = d.BlockedServices.Validate() + if d.conf.BlockedServices != nil { + err = d.conf.BlockedServices.Validate() if err != nil { return nil, fmt.Errorf("filtering: %w", err) } @@ -1042,16 +1111,16 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { } } - _ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755) + _ = os.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), 0o755) - d.loadFilters(d.Filters) - d.loadFilters(d.WhitelistFilters) + d.loadFilters(d.conf.Filters) + d.loadFilters(d.conf.WhitelistFilters) - d.Filters = deduplicateFilters(d.Filters) - d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters) + d.conf.Filters = deduplicateFilters(d.conf.Filters) + d.conf.WhitelistFilters = deduplicateFilters(d.conf.WhitelistFilters) - updateUniqueFilterID(d.Filters) - updateUniqueFilterID(d.WhitelistFilters) + updateUniqueFilterID(d.conf.Filters) + updateUniqueFilterID(d.conf.WhitelistFilters) return d, nil } diff --git a/internal/filtering/http.go b/internal/filtering/http.go index df3fe95e..ca6b8cf9 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -124,7 +124,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request return } - d.ConfigModified() + d.conf.ConfigModified() d.EnableFilters(true) _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) @@ -149,12 +149,12 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ var deleted FilterYAML func() { - d.filtersMu.Lock() - defer d.filtersMu.Unlock() + d.conf.filtersMu.Lock() + defer d.conf.filtersMu.Unlock() - filters := &d.Filters + filters := &d.conf.Filters if req.Whitelist { - filters = &d.WhitelistFilters + filters = &d.conf.WhitelistFilters } delIdx := slices.IndexFunc(*filters, func(flt FilterYAML) bool { @@ -167,7 +167,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ } deleted = (*filters)[delIdx] - p := deleted.Path(d.DataDir) + p := deleted.Path(d.conf.DataDir) err = os.Rename(p, p+".old") if err != nil && !errors.Is(err, os.ErrNotExist) { log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err) @@ -180,7 +180,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ log.Info("deleted filter %d", deleted.ID) }() - d.ConfigModified() + d.conf.ConfigModified() d.EnableFilters(true) // NOTE: The old files "filter.txt.old" aren't deleted. It's not really @@ -242,7 +242,7 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request return } - d.ConfigModified() + d.conf.ConfigModified() if restart { d.EnableFilters(true) } @@ -266,8 +266,8 @@ func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque return } - d.UserRules = req.Rules - d.ConfigModified() + d.conf.UserRules = req.Rules + d.conf.ConfigModified() d.EnableFilters(true) } @@ -340,19 +340,19 @@ func filterToJSON(f FilterYAML) filterJSON { // Get filtering configuration func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { resp := filteringConfig{} - d.filtersMu.RLock() - resp.Enabled = d.FilteringEnabled - resp.Interval = d.FiltersUpdateIntervalHours - for _, f := range d.Filters { + d.conf.filtersMu.RLock() + resp.Enabled = d.conf.FilteringEnabled + resp.Interval = d.conf.FiltersUpdateIntervalHours + for _, f := range d.conf.Filters { fj := filterToJSON(f) resp.Filters = append(resp.Filters, fj) } - for _, f := range d.WhitelistFilters { + for _, f := range d.conf.WhitelistFilters { fj := filterToJSON(f) resp.WhitelistFilters = append(resp.WhitelistFilters, fj) } - resp.UserRules = d.UserRules - d.filtersMu.RUnlock() + resp.UserRules = d.conf.UserRules + d.conf.filtersMu.RUnlock() aghhttp.WriteJSONResponseOK(w, r, resp) } @@ -374,14 +374,14 @@ func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request } func() { - d.filtersMu.Lock() - defer d.filtersMu.Unlock() + d.conf.filtersMu.Lock() + defer d.conf.filtersMu.Unlock() - d.FilteringEnabled = req.Enabled - d.FiltersUpdateIntervalHours = req.Interval + d.conf.FilteringEnabled = req.Enabled + d.conf.FiltersUpdateIntervalHours = req.Interval }() - d.ConfigModified() + d.conf.ConfigModified() d.EnableFilters(true) } @@ -484,15 +484,15 @@ func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) { // handleSafeBrowsingEnable is the handler for the POST // /control/safebrowsing/enable HTTP API. func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, true) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, true) + d.conf.ConfigModified() } // handleSafeBrowsingDisable is the handler for the POST // /control/safebrowsing/disable HTTP API. func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, false) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, false) + d.conf.ConfigModified() } // handleSafeBrowsingStatus is the handler for the GET @@ -501,7 +501,7 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ resp := &struct { Enabled bool `json:"enabled"` }{ - Enabled: protectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled), + Enabled: protectedBool(d.confMu, &d.conf.SafeBrowsingEnabled), } aghhttp.WriteJSONResponseOK(w, r, resp) @@ -510,15 +510,15 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ // handleParentalEnable is the handler for the POST /control/parental/enable // HTTP API. func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, true) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.ParentalEnabled, true) + d.conf.ConfigModified() } // handleParentalDisable is the handler for the POST /control/parental/disable // HTTP API. func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, false) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.ParentalEnabled, false) + d.conf.ConfigModified() } // handleParentalStatus is the handler for the GET /control/parental/status @@ -527,7 +527,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) resp := &struct { Enabled bool `json:"enabled"` }{ - Enabled: protectedBool(&d.confLock, &d.Config.ParentalEnabled), + Enabled: protectedBool(d.confMu, &d.conf.ParentalEnabled), } aghhttp.WriteJSONResponseOK(w, r, resp) @@ -535,7 +535,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) // RegisterFilteringHandlers - register handlers func (d *DNSFilter) RegisterFilteringHandlers() { - registerHTTP := d.HTTPRegister + registerHTTP := d.conf.HTTPRegister if registerHTTP == nil { return } diff --git a/internal/filtering/rewritehttp.go b/internal/filtering/rewritehttp.go index b3dc275a..ed34bb4b 100644 --- a/internal/filtering/rewritehttp.go +++ b/internal/filtering/rewritehttp.go @@ -15,22 +15,27 @@ type rewriteEntryJSON struct { Answer string `json:"answer"` } +// handleRewriteList is the handler for the GET /control/rewrite/list HTTP API. func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { arr := []*rewriteEntryJSON{} - d.confLock.Lock() - for _, ent := range d.Config.Rewrites { - jsent := rewriteEntryJSON{ - Domain: ent.Domain, - Answer: ent.Answer, + func() { + d.confMu.RLock() + defer d.confMu.RUnlock() + + for _, ent := range d.conf.Rewrites { + jsonEnt := rewriteEntryJSON{ + Domain: ent.Domain, + Answer: ent.Answer, + } + arr = append(arr, &jsonEnt) } - arr = append(arr, &jsent) - } - d.confLock.Unlock() + }() aghhttp.WriteJSONResponseOK(w, r, arr) } +// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API. func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { rwJSON := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&rwJSON) @@ -54,14 +59,24 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { return } - d.confLock.Lock() - d.Config.Rewrites = append(d.Config.Rewrites, rw) - d.confLock.Unlock() - log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites)) + func() { + d.confMu.Lock() + defer d.confMu.Unlock() - d.Config.ConfigModified() + d.conf.Rewrites = append(d.conf.Rewrites, rw) + log.Debug( + "rewrite: added element: %s -> %s [%d]", + rw.Domain, + rw.Answer, + len(d.conf.Rewrites), + ) + }() + + d.conf.ConfigModified() } +// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP +// API. func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) { jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) @@ -77,20 +92,23 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) } arr := []*LegacyRewrite{} - d.confLock.Lock() - for _, ent := range d.Config.Rewrites { - if ent.equal(entDel) { - log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer) + func() { + d.confMu.Lock() + defer d.confMu.Unlock() - continue + for _, ent := range d.conf.Rewrites { + if ent.equal(entDel) { + log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer) + + continue + } + + arr = append(arr, ent) } + d.conf.Rewrites = arr + }() - arr = append(arr, ent) - } - d.Config.Rewrites = arr - d.confLock.Unlock() - - d.Config.ConfigModified() + d.conf.ConfigModified() } // rewriteUpdateJSON is a struct for JSON object with rewrite rule update info. @@ -132,21 +150,21 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) index := -1 defer func() { if index >= 0 { - d.Config.ConfigModified() + d.conf.ConfigModified() } }() - d.confLock.Lock() - defer d.confLock.Unlock() + d.confMu.Lock() + defer d.confMu.Unlock() - index = slices.IndexFunc(d.Config.Rewrites, rwDel.equal) + index = slices.IndexFunc(d.conf.Rewrites, rwDel.equal) if index == -1 { aghhttp.Error(r, w, http.StatusBadRequest, "target rule not found") return } - d.Config.Rewrites = slices.Replace(d.Config.Rewrites, index, index+1, rwAdd) + d.conf.Rewrites = slices.Replace(d.conf.Rewrites, index, index+1, rwAdd) log.Debug("rewrite: removed element: %s -> %s", rwDel.Domain, rwDel.Answer) log.Debug("rewrite: added element: %s -> %s", rwAdd.Domain, rwAdd.Answer) diff --git a/internal/filtering/rewrites.go b/internal/filtering/rewrites.go index a716ecbf..fa1b7774 100644 --- a/internal/filtering/rewrites.go +++ b/internal/filtering/rewrites.go @@ -139,7 +139,7 @@ func (rw *LegacyRewrite) Compare(b *LegacyRewrite) (res int) { // prepareRewrites normalizes and validates all legacy DNS rewrites. func (d *DNSFilter) prepareRewrites() (err error) { - for i, r := range d.Rewrites { + for i, r := range d.conf.Rewrites { err = r.normalize() if err != nil { return fmt.Errorf("at index %d: %w", i, err) diff --git a/internal/filtering/rewrites_test.go b/internal/filtering/rewrites_test.go index 2b201b62..7f80df09 100644 --- a/internal/filtering/rewrites_test.go +++ b/internal/filtering/rewrites_test.go @@ -26,7 +26,7 @@ func TestRewrites(t *testing.T) { addr2v6 = netip.MustParseAddr("1234::5678") ) - d.Rewrites = []*LegacyRewrite{{ + d.conf.Rewrites = []*LegacyRewrite{{ // This one and below are about CNAME, A and AAAA. Domain: "somecname", Answer: "somehost.com", @@ -202,7 +202,7 @@ func TestRewritesLevels(t *testing.T) { d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Exact host, wildcard L2, wildcard L3. - d.Rewrites = []*LegacyRewrite{{ + d.conf.Rewrites = []*LegacyRewrite{{ Domain: "host.com", Answer: "1.1.1.1", Type: dns.TypeA, @@ -249,7 +249,7 @@ func TestRewritesExceptionCNAME(t *testing.T) { d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Wildcard and exception for a sub-domain. - d.Rewrites = []*LegacyRewrite{{ + d.conf.Rewrites = []*LegacyRewrite{{ Domain: "*.host.com", Answer: "2.2.2.2", }, { @@ -300,7 +300,7 @@ func TestRewritesExceptionIP(t *testing.T) { d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Exception for AAAA record. - d.Rewrites = []*LegacyRewrite{{ + d.conf.Rewrites = []*LegacyRewrite{{ Domain: "host.com", Answer: "1.2.3.4", Type: dns.TypeA, diff --git a/internal/filtering/safesearchhttp.go b/internal/filtering/safesearchhttp.go index d1358a2f..eb6fa401 100644 --- a/internal/filtering/safesearchhttp.go +++ b/internal/filtering/safesearchhttp.go @@ -12,8 +12,8 @@ import ( // // Deprecated: Use handleSafeSearchSettings. func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, true) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, true) + d.conf.ConfigModified() } // handleSafeSearchDisable is the handler for POST /control/safesearch/disable @@ -21,8 +21,8 @@ func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Reques // // Deprecated: Use handleSafeSearchSettings. func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, false) - d.Config.ConfigModified() + setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, false) + d.conf.ConfigModified() } // handleSafeSearchStatus is the handler for GET /control/safesearch/status @@ -30,10 +30,10 @@ func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Reque func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { var resp SafeSearchConfig func() { - d.confLock.RLock() - defer d.confLock.RUnlock() + d.confMu.RLock() + defer d.confMu.RUnlock() - resp = d.Config.SafeSearchConf + resp = d.conf.SafeSearchConf }() aghhttp.WriteJSONResponseOK(w, r, resp) @@ -59,13 +59,13 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ } func() { - d.confLock.Lock() - defer d.confLock.Unlock() + d.confMu.Lock() + defer d.confMu.Unlock() - d.Config.SafeSearchConf = conf + d.conf.SafeSearchConf = conf }() - d.Config.ConfigModified() + d.conf.ConfigModified() aghhttp.OK(w) }