diff --git a/internal/client/storage.go b/internal/client/storage.go index fbbfd1b8..d0e9e7fe 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" @@ -666,3 +667,34 @@ func (s *Storage) ClearUpstreamCache() { s.upstreamManager.clearUpstreamCache() } + +// ApplyClientFiltering retrieves persistent client information using the +// ClientID or client IP address, and applies it to the filtering settings. +func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) { + c, ok := s.index.findByClientID(id) + if !ok { + c, ok = s.index.findByIP(addr) + } + + if !ok { + return + } + + setts.ClientIP = addr + + if c.UseOwnBlockedServices { + setts.BlockedServices = c.BlockedServices.Clone() + } + + setts.ClientName = c.Name + setts.ClientTags = slices.Clone(c.Tags) + if !c.UseOwnSettings { + return + } + + setts.FilteringEnabled = c.FilteringEnabled + setts.SafeSearchEnabled = c.SafeSearchConf.Enabled + setts.ClientSafeSearch = c.SafeSearch + setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled + setts.ParentalEnabled = c.ParentalEnabled +} diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 0dd3471c..d485ed0d 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "net/netip" "os" "path/filepath" "slices" @@ -629,3 +630,18 @@ func (d *DNSFilter) enableFiltersLocked(async bool) { d.SetEnabled(d.conf.FilteringEnabled) } + +// applyAdditionalFiltering adds additional client information and settings if +// the client has them. +func (d *DNSFilter) applyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) { + d.applyClientFiltering(clientID, cliAddr, setts) + + if setts.BlockedServices != nil { + // TODO(e.burkov): Get rid of this crutch. + setts.ServicesRules = nil + svcs := setts.BlockedServices.IDs + if !setts.BlockedServices.Schedule.Contains(time.Now()) { + d.ApplyBlockedServicesList(setts, svcs) + } + } +} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 8836515c..941502a4 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -47,6 +47,8 @@ type Settings struct { ServicesRules []ServiceEntry + BlockedServices *BlockedServices + ProtectionEnabled bool FilteringEnabled bool SafeSearchEnabled bool @@ -78,6 +80,10 @@ type Config struct { SafeSearch SafeSearch `yaml:"-"` + // ApplyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"` + // BlockedServices is the configuration of blocked services. // Per-client settings can override this configuration. BlockedServices *BlockedServices `yaml:"blocked_services"` @@ -244,6 +250,13 @@ type DNSFilter struct { // parentalControl is the parental control hash-prefix checker. parentalControlChecker Checker + // applyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + // + // TODO(s.chzhen): !! Consider finding a better approach while taking an + // import cycle into account. + applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) + engineLock sync.RWMutex // confMu protects conf. @@ -998,6 +1011,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { refreshLock: &sync.Mutex{}, safeBrowsingChecker: c.SafeBrowsingChecker, parentalControlChecker: c.ParentalControlChecker, + applyClientFiltering: c.ApplyClientFiltering, confMu: &sync.RWMutex{}, } diff --git a/internal/filtering/http.go b/internal/filtering/http.go index 94a601f5..e3aa33b1 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "slices" + "strconv" "sync" "time" @@ -422,13 +423,35 @@ type checkHostResp struct { func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { host := r.URL.Query().Get("name") + cli := r.URL.Query().Get("client") + qTypeStr := r.URL.Query().Get("qtype") + qType, err := stringToDNSType(qTypeStr) + if err != nil { + aghhttp.Error( + r, + w, + http.StatusUnprocessableEntity, + "bad qtype query parameter: %s", + qTypeStr, + ) + + return + } setts := d.Settings() setts.FilteringEnabled = true setts.ProtectionEnabled = true d.ApplyBlockedServices(setts) - result, err := d.CheckHost(host, dns.TypeA, setts) + + addr, err := netip.ParseAddr(cli) + if err == nil { + d.applyAdditionalFiltering(addr, "", setts) + } else if cli != "" { + d.applyAdditionalFiltering(netip.Addr{}, cli, setts) + } + + result, err := d.CheckHost(host, qType, setts) if err != nil { aghhttp.Error( r, @@ -466,6 +489,26 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { aghhttp.WriteJSONResponseOK(w, r, resp) } +// stringToDNSType is a helper function that converts a string to DNS type. If +// the string is empty, it returns the default value [dns.TypeA]. +func stringToDNSType(str string) (qtype uint16, err error) { + if str == "" { + return dns.TypeA, nil + } + + qtype, ok := dns.StringToType[str] + if ok { + return qtype, nil + } + + val, err := strconv.ParseUint(str, 10, 16) + if err == nil { + return uint16(val), nil + } + + return 0, errors.ErrBadEnumValue +} + // setProtectedBool sets the value of a boolean pointer under a lock. l must // protect the value under ptr. // diff --git a/internal/home/clients.go b/internal/home/clients.go index e2fd62fb..92f2a4a9 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -119,6 +119,8 @@ func (clients *clientsContainer) Init( sigHdlr.addClientStorage(clients.storage) + filteringConf.ApplyClientFiltering = clients.storage.ApplyClientFiltering + return nil }