mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-24 22:15:45 +03:00
+ use per-client DNS servers
This commit is contained in:
parent
b6b26c0681
commit
7313c3bc53
7 changed files with 58 additions and 6 deletions
|
@ -669,6 +669,7 @@ Response:
|
|||
key: "value"
|
||||
...
|
||||
}
|
||||
upstreams: ["upstream1", ...]
|
||||
}
|
||||
]
|
||||
auto_clients: [
|
||||
|
@ -703,6 +704,7 @@ Request:
|
|||
safesearch_enabled: false
|
||||
use_global_blocked_services: true
|
||||
blocked_services: [ "name1", ... ]
|
||||
upstreams: ["upstream1", ...]
|
||||
}
|
||||
|
||||
Response:
|
||||
|
@ -732,6 +734,7 @@ Request:
|
|||
safesearch_enabled: false
|
||||
use_global_blocked_services: true
|
||||
blocked_services: [ "name1", ... ]
|
||||
upstreams: ["upstream1", ...]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -94,6 +94,9 @@ type FilteringConfig struct {
|
|||
// Filtering callback function
|
||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
|
||||
// This callback function returns the list of upstream servers for a client specified by IP address
|
||||
GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"`
|
||||
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
|
||||
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
|
||||
|
@ -393,6 +396,19 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||
clientIP, _, _ := net.SplitHostPort(d.Addr.String())
|
||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||
for _, us := range upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
d.Upstreams = append(d.Upstreams, u)
|
||||
}
|
||||
}
|
||||
|
||||
// request was not filtered so let it be processed further
|
||||
err = p.Resolve(d)
|
||||
if err != nil {
|
||||
|
|
|
@ -44,7 +44,7 @@ func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request)
|
|||
return
|
||||
}
|
||||
|
||||
err = validateUpstreams(req.Upstreams)
|
||||
err = ValidateUpstreams(req.Upstreams)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
|
||||
return
|
||||
|
@ -78,8 +78,8 @@ func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
}
|
||||
|
||||
// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
||||
func validateUpstreams(upstreams []string) error {
|
||||
// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
||||
func ValidateUpstreams(upstreams []string) error {
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
d, err := validateUpstream(u)
|
||||
|
|
|
@ -762,21 +762,21 @@ func TestValidateUpstreamsSet(t *testing.T) {
|
|||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
}
|
||||
err := validateUpstreams(upstreamsSet)
|
||||
err := ValidateUpstreams(upstreamsSet)
|
||||
if err == nil {
|
||||
t.Fatalf("there is no default upstream")
|
||||
}
|
||||
|
||||
// Let's add default upstream
|
||||
upstreamsSet = append(upstreamsSet, "8.8.8.8")
|
||||
err = validateUpstreams(upstreamsSet)
|
||||
err = ValidateUpstreams(upstreamsSet)
|
||||
if err != nil {
|
||||
t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err)
|
||||
}
|
||||
|
||||
// Let's add invalid upstream
|
||||
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
|
||||
err = validateUpstreams(upstreamsSet)
|
||||
err = ValidateUpstreams(upstreamsSet)
|
||||
if err == nil {
|
||||
t.Fatalf("there is an invalid upstream in set, but it pass through validation")
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
@ -34,6 +35,8 @@ type Client struct {
|
|||
|
||||
UseOwnBlockedServices bool // false: use global settings
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
|
@ -96,6 +99,8 @@ type clientObject struct {
|
|||
|
||||
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
|
@ -111,6 +116,8 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
|||
|
||||
UseOwnBlockedServices: !cy.UseGlobalBlockedServices,
|
||||
BlockedServices: cy.BlockedServices,
|
||||
|
||||
Upstreams: cy.Upstreams,
|
||||
}
|
||||
_, err := clients.Add(cli)
|
||||
if err != nil {
|
||||
|
@ -134,6 +141,8 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
|||
|
||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||
BlockedServices: cli.BlockedServices,
|
||||
|
||||
Upstreams: cli.Upstreams,
|
||||
}
|
||||
*objects = append(*objects, cy)
|
||||
}
|
||||
|
@ -268,6 +277,14 @@ func (c *Client) check() error {
|
|||
|
||||
return fmt.Errorf("Invalid ID: %s", id)
|
||||
}
|
||||
|
||||
if len(c.Upstreams) != 0 {
|
||||
err := dnsforward.ValidateUpstreams(c.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid upstream servers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ type clientJSON struct {
|
|||
|
||||
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
|
||||
BlockedServices []string `json:"blocked_services"`
|
||||
|
||||
Upstreams []string `json:"upstreams"`
|
||||
}
|
||||
|
||||
type clientHostJSON struct {
|
||||
|
@ -92,6 +94,8 @@ func jsonToClient(cj clientJSON) (*Client, error) {
|
|||
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
|
||||
Upstreams: cj.Upstreams,
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
@ -109,6 +113,8 @@ func clientToJSON(c *Client) clientJSON {
|
|||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
BlockedServices: c.BlockedServices,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
|
|
10
home/dns.go
10
home/dns.go
|
@ -170,9 +170,19 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
|
|||
}
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
return newconfig, nil
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []string {
|
||||
c, ok := config.clients.Find(clientAddr)
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr)
|
||||
return c.Upstreams
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
|
||||
|
|
Loading…
Reference in a new issue