all: check host

This commit is contained in:
Stanislav Chzhen 2025-03-04 15:13:06 +03:00
parent 318bd2901a
commit ce0d6117ad
5 changed files with 108 additions and 1 deletions

View file

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
@ -666,3 +667,34 @@ func (s *Storage) ClearUpstreamCache() {
s.upstreamManager.clearUpstreamCache()
}
// ApplyClientFiltering retrieves persistent client information using the
// ClientID or client IP address, and applies it to the filtering settings.
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
c, ok := s.index.findByClientID(id)
if !ok {
c, ok = s.index.findByIP(addr)
}
if !ok {
return
}
setts.ClientIP = addr
if c.UseOwnBlockedServices {
setts.BlockedServices = c.BlockedServices.Clone()
}
setts.ClientName = c.Name
setts.ClientTags = slices.Clone(c.Tags)
if !c.UseOwnSettings {
return
}
setts.FilteringEnabled = c.FilteringEnabled
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
setts.ClientSafeSearch = c.SafeSearch
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"os"
"path/filepath"
"slices"
@ -629,3 +630,18 @@ func (d *DNSFilter) enableFiltersLocked(async bool) {
d.SetEnabled(d.conf.FilteringEnabled)
}
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func (d *DNSFilter) applyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) {
d.applyClientFiltering(clientID, cliAddr, setts)
if setts.BlockedServices != nil {
// TODO(e.burkov): Get rid of this crutch.
setts.ServicesRules = nil
svcs := setts.BlockedServices.IDs
if !setts.BlockedServices.Schedule.Contains(time.Now()) {
d.ApplyBlockedServicesList(setts, svcs)
}
}
}

View file

@ -47,6 +47,8 @@ type Settings struct {
ServicesRules []ServiceEntry
BlockedServices *BlockedServices
ProtectionEnabled bool
FilteringEnabled bool
SafeSearchEnabled bool
@ -78,6 +80,10 @@ type Config struct {
SafeSearch SafeSearch `yaml:"-"`
// ApplyClientFiltering retrieves persistent client information using the
// ClientID or client IP address, and applies it to the filtering settings.
ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"`
// BlockedServices is the configuration of blocked services.
// Per-client settings can override this configuration.
BlockedServices *BlockedServices `yaml:"blocked_services"`
@ -244,6 +250,13 @@ type DNSFilter struct {
// parentalControl is the parental control hash-prefix checker.
parentalControlChecker Checker
// applyClientFiltering retrieves persistent client information using the
// ClientID or client IP address, and applies it to the filtering settings.
//
// TODO(s.chzhen): !! Consider finding a better approach while taking an
// import cycle into account.
applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings)
engineLock sync.RWMutex
// confMu protects conf.
@ -998,6 +1011,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
refreshLock: &sync.Mutex{},
safeBrowsingChecker: c.SafeBrowsingChecker,
parentalControlChecker: c.ParentalControlChecker,
applyClientFiltering: c.ApplyClientFiltering,
confMu: &sync.RWMutex{},
}

View file

@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"sync"
"time"
@ -422,13 +423,35 @@ type checkHostResp struct {
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
host := r.URL.Query().Get("name")
cli := r.URL.Query().Get("client")
qTypeStr := r.URL.Query().Get("qtype")
qType, err := stringToDNSType(qTypeStr)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusUnprocessableEntity,
"bad qtype query parameter: %s",
qTypeStr,
)
return
}
setts := d.Settings()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
d.ApplyBlockedServices(setts)
result, err := d.CheckHost(host, dns.TypeA, setts)
addr, err := netip.ParseAddr(cli)
if err == nil {
d.applyAdditionalFiltering(addr, "", setts)
} else if cli != "" {
d.applyAdditionalFiltering(netip.Addr{}, cli, setts)
}
result, err := d.CheckHost(host, qType, setts)
if err != nil {
aghhttp.Error(
r,
@ -466,6 +489,26 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// stringToDNSType is a helper function that converts a string to DNS type. If
// the string is empty, it returns the default value [dns.TypeA].
func stringToDNSType(str string) (qtype uint16, err error) {
if str == "" {
return dns.TypeA, nil
}
qtype, ok := dns.StringToType[str]
if ok {
return qtype, nil
}
val, err := strconv.ParseUint(str, 10, 16)
if err == nil {
return uint16(val), nil
}
return 0, errors.ErrBadEnumValue
}
// setProtectedBool sets the value of a boolean pointer under a lock. l must
// protect the value under ptr.
//

View file

@ -119,6 +119,8 @@ func (clients *clientsContainer) Init(
sigHdlr.addClientStorage(clients.storage)
filteringConf.ApplyClientFiltering = clients.storage.ApplyClientFiltering
return nil
}