mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-04-01 23:13:36 +03:00
all: check host
This commit is contained in:
parent
318bd2901a
commit
ce0d6117ad
5 changed files with 108 additions and 1 deletions
internal
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
@ -666,3 +667,34 @@ func (s *Storage) ClearUpstreamCache() {
|
|||
|
||||
s.upstreamManager.clearUpstreamCache()
|
||||
}
|
||||
|
||||
// ApplyClientFiltering retrieves persistent client information using the
|
||||
// ClientID or client IP address, and applies it to the filtering settings.
|
||||
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
|
||||
c, ok := s.index.findByClientID(id)
|
||||
if !ok {
|
||||
c, ok = s.index.findByIP(addr)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
setts.ClientIP = addr
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
setts.BlockedServices = c.BlockedServices.Clone()
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
setts.ClientTags = slices.Clone(c.Tags)
|
||||
if !c.UseOwnSettings {
|
||||
return
|
||||
}
|
||||
|
||||
setts.FilteringEnabled = c.FilteringEnabled
|
||||
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
|
||||
setts.ClientSafeSearch = c.SafeSearch
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
@ -629,3 +630,18 @@ func (d *DNSFilter) enableFiltersLocked(async bool) {
|
|||
|
||||
d.SetEnabled(d.conf.FilteringEnabled)
|
||||
}
|
||||
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func (d *DNSFilter) applyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) {
|
||||
d.applyClientFiltering(clientID, cliAddr, setts)
|
||||
|
||||
if setts.BlockedServices != nil {
|
||||
// TODO(e.burkov): Get rid of this crutch.
|
||||
setts.ServicesRules = nil
|
||||
svcs := setts.BlockedServices.IDs
|
||||
if !setts.BlockedServices.Schedule.Contains(time.Now()) {
|
||||
d.ApplyBlockedServicesList(setts, svcs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ type Settings struct {
|
|||
|
||||
ServicesRules []ServiceEntry
|
||||
|
||||
BlockedServices *BlockedServices
|
||||
|
||||
ProtectionEnabled bool
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
|
@ -78,6 +80,10 @@ type Config struct {
|
|||
|
||||
SafeSearch SafeSearch `yaml:"-"`
|
||||
|
||||
// ApplyClientFiltering retrieves persistent client information using the
|
||||
// ClientID or client IP address, and applies it to the filtering settings.
|
||||
ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"`
|
||||
|
||||
// BlockedServices is the configuration of blocked services.
|
||||
// Per-client settings can override this configuration.
|
||||
BlockedServices *BlockedServices `yaml:"blocked_services"`
|
||||
|
@ -244,6 +250,13 @@ type DNSFilter struct {
|
|||
// parentalControl is the parental control hash-prefix checker.
|
||||
parentalControlChecker Checker
|
||||
|
||||
// applyClientFiltering retrieves persistent client information using the
|
||||
// ClientID or client IP address, and applies it to the filtering settings.
|
||||
//
|
||||
// TODO(s.chzhen): !! Consider finding a better approach while taking an
|
||||
// import cycle into account.
|
||||
applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings)
|
||||
|
||||
engineLock sync.RWMutex
|
||||
|
||||
// confMu protects conf.
|
||||
|
@ -998,6 +1011,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
|||
refreshLock: &sync.Mutex{},
|
||||
safeBrowsingChecker: c.SafeBrowsingChecker,
|
||||
parentalControlChecker: c.ParentalControlChecker,
|
||||
applyClientFiltering: c.ApplyClientFiltering,
|
||||
confMu: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -422,13 +423,35 @@ type checkHostResp struct {
|
|||
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.URL.Query().Get("name")
|
||||
cli := r.URL.Query().Get("client")
|
||||
qTypeStr := r.URL.Query().Get("qtype")
|
||||
qType, err := stringToDNSType(qTypeStr)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusUnprocessableEntity,
|
||||
"bad qtype query parameter: %s",
|
||||
qTypeStr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
setts := d.Settings()
|
||||
setts.FilteringEnabled = true
|
||||
setts.ProtectionEnabled = true
|
||||
|
||||
d.ApplyBlockedServices(setts)
|
||||
result, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
|
||||
addr, err := netip.ParseAddr(cli)
|
||||
if err == nil {
|
||||
d.applyAdditionalFiltering(addr, "", setts)
|
||||
} else if cli != "" {
|
||||
d.applyAdditionalFiltering(netip.Addr{}, cli, setts)
|
||||
}
|
||||
|
||||
result, err := d.CheckHost(host, qType, setts)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
|
@ -466,6 +489,26 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// stringToDNSType is a helper function that converts a string to DNS type. If
|
||||
// the string is empty, it returns the default value [dns.TypeA].
|
||||
func stringToDNSType(str string) (qtype uint16, err error) {
|
||||
if str == "" {
|
||||
return dns.TypeA, nil
|
||||
}
|
||||
|
||||
qtype, ok := dns.StringToType[str]
|
||||
if ok {
|
||||
return qtype, nil
|
||||
}
|
||||
|
||||
val, err := strconv.ParseUint(str, 10, 16)
|
||||
if err == nil {
|
||||
return uint16(val), nil
|
||||
}
|
||||
|
||||
return 0, errors.ErrBadEnumValue
|
||||
}
|
||||
|
||||
// setProtectedBool sets the value of a boolean pointer under a lock. l must
|
||||
// protect the value under ptr.
|
||||
//
|
||||
|
|
|
@ -119,6 +119,8 @@ func (clients *clientsContainer) Init(
|
|||
|
||||
sigHdlr.addClientStorage(clients.storage)
|
||||
|
||||
filteringConf.ApplyClientFiltering = clients.storage.ApplyClientFiltering
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue