* dnsforward: create dnsfilter asynchronously

This commit is contained in:
Simon Zolin 2019-09-20 13:48:22 +03:00
parent f6404ef181
commit 75b864f25e
2 changed files with 79 additions and 5 deletions

View file

@ -44,6 +44,12 @@ type Server struct {
queryLog querylog.QueryLog // Query log instance
stats stats.Stats
// How many times the server was started
// While creating a dnsfilter object,
// we use this value to set s.dnsFilter property only with the most recent settings.
startCounter uint32
dnsfilterCreatorChan chan dnsfilterCreatorParams
AllowedClients map[string]bool // IP addresses of whitelist clients
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
@ -54,6 +60,11 @@ type Server struct {
conf ServerConfig
}
type dnsfilterCreatorParams struct {
conf dnsfilter.Config
filters map[int]string
}
// NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once
func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server {
@ -73,6 +84,12 @@ func (s *Server) Close() {
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
// The zero FilteringConfig is empty and ready for use.
type FilteringConfig struct {
// Create dnsfilter asynchronously.
// Requests won't be filtered until dnsfilter is created.
// If "restart" command is received while we're creating an old dnsfilter object,
// we delay creation of the new object until the old one is created.
AsyncStartup bool
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
@ -254,8 +271,6 @@ func (s *Server) startInternal(config *ServerConfig) error {
// Initializes the DNS filter
func (s *Server) initDNSFilter(config *ServerConfig) error {
log.Tracef("Creating dnsfilter")
if config != nil {
s.conf = *config
}
@ -280,13 +295,71 @@ func (s *Server) initDNSFilter(config *ServerConfig) error {
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
}
s.dnsFilter = dnsfilter.New(&s.conf.Config, filters)
if s.dnsFilter == nil {
return fmt.Errorf("could not initialize dnsfilter")
if s.conf.AsyncStartup {
params := dnsfilterCreatorParams{
conf: s.conf.Config,
filters: filters,
}
s.startCounter++
if s.startCounter == 1 {
s.dnsfilterCreatorChan = make(chan dnsfilterCreatorParams, 1)
go s.dnsfilterCreator()
}
// remove all pending tasks
stop := false
for !stop {
select {
case <-s.dnsfilterCreatorChan:
//
default:
stop = true
}
}
s.dnsfilterCreatorChan <- params
} else {
log.Debug("creating dnsfilter...")
f := dnsfilter.New(&s.conf.Config, filters)
if f == nil {
return fmt.Errorf("could not initialize dnsfilter")
}
log.Debug("created dnsfilter")
s.dnsFilter = f
}
return nil
}
func (s *Server) dnsfilterCreator() {
for {
params := <-s.dnsfilterCreatorChan
s.Lock()
counter := s.startCounter
s.Unlock()
log.Debug("creating dnsfilter...")
f := dnsfilter.New(&params.conf, params.filters)
if f == nil {
log.Error("could not initialize dnsfilter")
continue
}
set := false
s.Lock()
if counter == s.startCounter {
s.dnsFilter = f
set = true
}
s.Unlock()
if set {
log.Debug("created and activated dnsfilter")
} else {
log.Debug("created dnsfilter")
}
}
}
// Stop stops the DNS server
func (s *Server) Stop() error {
s.Lock()

View file

@ -104,6 +104,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
FilteringConfig: config.DNS.FilteringConfig,
Filters: filters,
}
newconfig.AsyncStartup = true
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"