diff --git a/ci.sh b/ci.sh index 362c61d0..544f0109 100755 --- a/ci.sh +++ b/ci.sh @@ -16,14 +16,14 @@ golangci-lint --version # Run linter golangci-lint run -# Run tests -go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./... - # Make make clean make build/static/index.html make +# Run tests +go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./... + # if [[ -z "$(git status --porcelain)" ]]; then # # Working directory clean # echo "Git status is clean" diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index afd0149e..404d1ca0 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -2,7 +2,6 @@ package dnsforward import ( "crypto/tls" - "errors" "fmt" "net" "net/http" @@ -51,7 +50,12 @@ type Server struct { stats stats.Stats access *accessCtx + // DNS proxy instance for internal usage + // We don't Start() it and so no listen port is required. + internalProxy *proxy.Proxy + webRegistered bool + isRunning bool sync.RWMutex conf ServerConfig @@ -78,6 +82,7 @@ func (s *Server) Close() { s.dnsFilter = nil s.stats = nil s.queryLog = nil + s.dnsProxy = nil s.Unlock() } @@ -165,28 +170,54 @@ var defaultValues = ServerConfig{ FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, } +// Resolve - get IP addresses by host name from an upstream server. +// No request/response filtering is performed. +// Query log and Stats are not updated. +// This method may be called before Start(). +func (s *Server) Resolve(host string) ([]net.IPAddr, error) { + s.RLock() + defer s.RUnlock() + return s.internalProxy.LookupIPAddr(host) +} + +// Exchange - send DNS request to an upstream server and receive response +// No request/response filtering is performed. +// Query log and Stats are not updated. +// This method may be called before Start(). +func (s *Server) Exchange(req *dns.Msg) (*dns.Msg, error) { + s.RLock() + defer s.RUnlock() + + ctx := &proxy.DNSContext{ + Proto: "udp", + Req: req, + StartTime: time.Now(), + } + err := s.internalProxy.Resolve(ctx) + if err != nil { + return nil, err + } + return ctx.Res, nil +} + // Start starts the DNS server -func (s *Server) Start(config *ServerConfig) error { +func (s *Server) Start() error { s.Lock() defer s.Unlock() - return s.startInternal(config) + return s.startInternal() } // startInternal starts without locking -func (s *Server) startInternal(config *ServerConfig) error { - err := s.prepare(config) - if err != nil { - return err +func (s *Server) startInternal() error { + err := s.dnsProxy.Start() + if err == nil { + s.isRunning = true } - return s.dnsProxy.Start() + return err } // Prepare the object -func (s *Server) prepare(config *ServerConfig) error { - if s.dnsProxy != nil { - return errors.New("DNS server is already started") - } - +func (s *Server) Prepare(config *ServerConfig) error { if config != nil { s.conf = *config } @@ -234,6 +265,14 @@ func (s *Server) prepare(config *ServerConfig) error { EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, } + intlProxyConfig := proxy.Config{ + CacheEnabled: true, + CacheSizeBytes: 4096, + Upstreams: s.conf.Upstreams, + DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, + } + s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} + s.access = &accessCtx{} err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) if err != nil { @@ -277,24 +316,20 @@ func (s *Server) Stop() error { func (s *Server) stopInternal() error { if s.dnsProxy != nil { err := s.dnsProxy.Stop() - s.dnsProxy = nil if err != nil { return errorx.Decorate(err, "could not stop the DNS server properly") } } + s.isRunning = false return nil } // IsRunning returns true if the DNS server is running func (s *Server) IsRunning() bool { s.RLock() - isRunning := true - if s.dnsProxy == nil { - isRunning = false - } - s.RUnlock() - return isRunning + defer s.RUnlock() + return s.isRunning } // Restart - restart server @@ -306,7 +341,7 @@ func (s *Server) Restart() error { if err != nil { return errorx.Decorate(err, "could not reconfigure the server") } - err = s.startInternal(nil) + err = s.startInternal() if err != nil { return errorx.Decorate(err, "could not reconfigure the server") } @@ -330,7 +365,12 @@ func (s *Server) Reconfigure(config *ServerConfig) error { time.Sleep(1 * time.Second) } - err = s.startInternal(config) + err = s.Prepare(config) + if err != nil { + return errorx.Decorate(err, "could not reconfigure the server") + } + + err = s.startInternal() if err != nil { return errorx.Decorate(err, "could not reconfigure the server") } diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 43970a7e..0e7339ce 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -28,7 +28,7 @@ const ( func TestServer(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -62,7 +62,7 @@ func TestServer(t *testing.T) { func TestServerWithProtectionDisabled(t *testing.T) { s := createTestServer(t) s.conf.ProtectionEnabled = false - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -94,8 +94,9 @@ func TestDotServer(t *testing.T) { PrivateKeyData: keyPem, } + _ = s.Prepare(nil) // Starting the server - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -127,7 +128,7 @@ func TestDotServer(t *testing.T) { func TestServerRace(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -150,7 +151,7 @@ func TestServerRace(t *testing.T) { func TestSafeSearch(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -191,7 +192,7 @@ func TestSafeSearch(t *testing.T) { func TestInvalidRequest(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -217,7 +218,7 @@ func TestInvalidRequest(t *testing.T) { func TestBlockedRequest(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -319,7 +320,7 @@ func (u *testUpstream) Address() string { func (s *Server) startWithUpstream(u upstream.Upstream) error { s.Lock() defer s.Unlock() - err := s.prepare(nil) + err := s.Prepare(nil) if err != nil { return err } @@ -386,7 +387,7 @@ func TestBlockCNAME(t *testing.T) { func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t) s.conf.FilteringConfig.BlockingMode = "null_ip" - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -425,7 +426,7 @@ func TestNullBlockedRequest(t *testing.T) { func TestBlockedByHosts(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -464,7 +465,7 @@ func TestBlockedByHosts(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) { s := createTestServer(t) - err := s.Start(nil) + err := s.Start() if err != nil { t.Fatalf("Failed to start server: %s", err) } @@ -530,6 +531,8 @@ func createTestServer(t *testing.T) *Server { s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} s.conf.FilteringConfig.ProtectionEnabled = true + err := s.Prepare(nil) + assert.True(t, err == nil) return s } diff --git a/go.mod b/go.mod index fffa32e8..a3506604 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.13 require ( - github.com/AdguardTeam/dnsproxy v0.22.0 + github.com/AdguardTeam/dnsproxy v0.23.0 github.com/AdguardTeam/golibs v0.3.0 github.com/AdguardTeam/urlfilter v0.7.0 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index f3f3b62a..4c37a7ac 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.22.0 h1:8mpPu+KN0puFTHNhGy7XQ13fe3+3DGFaiwnqhNMWl+M= -github.com/AdguardTeam/dnsproxy v0.22.0/go.mod h1:2qy8rpdfBzKgMPxkHmPdaNK4XZJ322v4KtVGI8s8Bn0= +github.com/AdguardTeam/dnsproxy v0.23.0 h1:GrOUapcWjf19MF8NznZUbcYujBbl7QXapBWTFKqkJQg= +github.com/AdguardTeam/dnsproxy v0.23.0/go.mod h1:2qy8rpdfBzKgMPxkHmPdaNK4XZJ322v4KtVGI8s8Bn0= github.com/AdguardTeam/golibs v0.2.4 h1:GUssokegKxKF13K67Pgl0ZGwqHjNN6X7sep5ik6ORdY= github.com/AdguardTeam/golibs v0.2.4/go.mod h1:R3M+mAg3nWG4X4Hsag5eef/TckHFH12ZYhK7AzJc8+U= github.com/AdguardTeam/golibs v0.3.0 h1:1zO8ulGEOdXDDM++Ap4sYfTsT/Z4tZBZtiWSA4ykcOU= diff --git a/home/clients.go b/home/clients.go index b8e43d6f..573fae4d 100644 --- a/home/clients.go +++ b/home/clients.go @@ -128,24 +128,30 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { // WriteDiskConfig - write configuration func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) { - clientsList := clients.GetList() - for _, cli := range clientsList { + clients.lock.Lock() + for _, cli := range clients.list { cy := clientObject{ - Name: cli.Name, - IDs: cli.IDs, - UseGlobalSettings: !cli.UseOwnSettings, - FilteringEnabled: cli.FilteringEnabled, - ParentalEnabled: cli.ParentalEnabled, - SafeSearchEnabled: cli.SafeSearchEnabled, - SafeBrowsingEnabled: cli.SafeBrowsingEnabled, - + Name: cli.Name, + UseGlobalSettings: !cli.UseOwnSettings, + FilteringEnabled: cli.FilteringEnabled, + ParentalEnabled: cli.ParentalEnabled, + SafeSearchEnabled: cli.SafeSearchEnabled, + SafeBrowsingEnabled: cli.SafeBrowsingEnabled, UseGlobalBlockedServices: !cli.UseOwnBlockedServices, - BlockedServices: cli.BlockedServices, - - Upstreams: cli.Upstreams, } + + cy.IDs = make([]string, len(cli.IDs)) + copy(cy.IDs, cli.IDs) + + cy.BlockedServices = make([]string, len(cli.BlockedServices)) + copy(cy.BlockedServices, cli.BlockedServices) + + cy.Upstreams = make([]string, len(cli.Upstreams)) + copy(cy.Upstreams, cli.Upstreams) + *objects = append(*objects, cy) } + clients.lock.Unlock() } func (clients *clientsContainer) periodicUpdate() { @@ -157,11 +163,6 @@ func (clients *clientsContainer) periodicUpdate() { } } -// GetList returns the pointer to clients list -func (clients *clientsContainer) GetList() map[string]*Client { - return clients.list -} - // Exists checks if client with this IP already exists func (clients *clientsContainer) Exists(ip string, source clientSource) bool { clients.lock.Lock() diff --git a/home/config.go b/home/config.go index cb264787..3d99160b 100644 --- a/home/config.go +++ b/home/config.go @@ -29,6 +29,7 @@ type logSettings struct { Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled } +// HTTPSServer - HTTPS Server type HTTPSServer struct { server *http.Server cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey @@ -51,25 +52,15 @@ type configuration struct { runningAsService bool disableUpdate bool // If set, don't check for updates appSignalChannel chan os.Signal - clients clientsContainer // per-client-settings module controlLock sync.Mutex transport *http.Transport client *http.Client - stats stats.Stats // statistics module - queryLog querylog.QueryLog // query log module - auth *Auth // HTTP authentication module + auth *Auth // HTTP authentication module // cached version.json to avoid hammering github.io for each page reload versionCheckJSON []byte versionCheckLastTime time.Time - dnsctx dnsContext - dnsFilter *dnsfilter.Dnsfilter - dnsServer *dnsforward.Server - dhcpServer *dhcpd.Server - httpServer *http.Server - httpsServer HTTPSServer - BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server Users []User `yaml:"users"` // Users that can access HTTP server @@ -296,41 +287,41 @@ func (c *configuration) write() error { c.Lock() defer c.Unlock() - config.clients.WriteDiskConfig(&config.Clients) + Context.clients.WriteDiskConfig(&config.Clients) if config.auth != nil { config.Users = config.auth.GetUsers() } - if config.stats != nil { + if Context.stats != nil { sdc := stats.DiskConfig{} - config.stats.WriteDiskConfig(&sdc) + Context.stats.WriteDiskConfig(&sdc) config.DNS.StatsInterval = sdc.Interval } - if config.queryLog != nil { + if Context.queryLog != nil { dc := querylog.DiskConfig{} - config.queryLog.WriteDiskConfig(&dc) + Context.queryLog.WriteDiskConfig(&dc) config.DNS.QueryLogEnabled = dc.Enabled config.DNS.QueryLogInterval = dc.Interval config.DNS.QueryLogMemSize = dc.MemSize } - if config.dnsFilter != nil { + if Context.dnsFilter != nil { c := dnsfilter.Config{} - config.dnsFilter.WriteDiskConfig(&c) + Context.dnsFilter.WriteDiskConfig(&c) config.DNS.DnsfilterConf = c } - if config.dnsServer != nil { + if Context.dnsServer != nil { c := dnsforward.FilteringConfig{} - config.dnsServer.WriteDiskConfig(&c) + Context.dnsServer.WriteDiskConfig(&c) config.DNS.FilteringConfig = c } - if config.dhcpServer != nil { + if Context.dhcpServer != nil { c := dhcpd.ServerConfig{} - config.dhcpServer.WriteDiskConfig(&c) + Context.dhcpServer.WriteDiskConfig(&c) config.DHCP = c } diff --git a/home/control.go b/home/control.go index e27164c9..2953cf14 100644 --- a/home/control.go +++ b/home/control.go @@ -93,8 +93,8 @@ func getDNSAddresses() []string { func handleStatus(w http.ResponseWriter, r *http.Request) { c := dnsforward.FilteringConfig{} - if config.dnsServer != nil { - config.dnsServer.WriteDiskConfig(&c) + if Context.dnsServer != nil { + Context.dnsServer.WriteDiskConfig(&c) } data := map[string]interface{}{ "dns_addresses": getDNSAddresses(), @@ -154,7 +154,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) { return } - config.dnsServer.ServeHTTP(w, r) + Context.dnsServer.ServeHTTP(w, r) } // ------------------------ diff --git a/home/control_install.go b/home/control_install.go index 7e9e315f..61abeb19 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -235,13 +235,19 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { config.DNS.BindHost = newSettings.DNS.IP config.DNS.Port = newSettings.DNS.Port - initDNSServer() - - err = startDNSServer() - if err != nil { + err = initDNSServer() + var err2 error + if err == nil { + err2 = startDNSServer() + } + if err != nil || err2 != nil { config.firstRun = true copyInstallSettings(&config, &curConfig) - httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err) + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err) + } else { + httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2) + } return } @@ -261,7 +267,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely if restartHTTP { go func() { - _ = config.httpServer.Shutdown(context.TODO()) + _ = Context.httpServer.Shutdown(context.TODO()) }() } diff --git a/home/control_tls.go b/home/control_tls.go index 0f20bb91..f0f4c655 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -80,7 +80,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { // check if port is available // BUT: if we are already using this port, no need alreadyRunning := false - if config.httpsServer.server != nil { + if Context.httpsServer.server != nil { alreadyRunning = true } if !alreadyRunning { @@ -110,7 +110,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { // check if port is available // BUT: if we are already using this port, no need alreadyRunning := false - if config.httpsServer.server != nil { + if Context.httpsServer.server != nil { alreadyRunning = true } if !alreadyRunning { @@ -145,12 +145,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { if restartHTTPS { go func() { time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server - config.httpsServer.cond.L.Lock() - config.httpsServer.cond.Broadcast() - if config.httpsServer.server != nil { - config.httpsServer.server.Shutdown(context.TODO()) + Context.httpsServer.cond.L.Lock() + Context.httpsServer.cond.Broadcast() + if Context.httpsServer.server != nil { + Context.httpsServer.server.Shutdown(context.TODO()) } - config.httpsServer.cond.L.Unlock() + Context.httpsServer.cond.L.Unlock() }() } } diff --git a/home/dhcp.go b/home/dhcp.go index dcb7a28b..799333b0 100644 --- a/home/dhcp.go +++ b/home/dhcp.go @@ -10,12 +10,12 @@ func startDHCPServer() error { return nil } - err := config.dhcpServer.Init(config.DHCP) + err := Context.dhcpServer.Init(config.DHCP) if err != nil { return errorx.Decorate(err, "Couldn't init DHCP server") } - err = config.dhcpServer.Start() + err = Context.dhcpServer.Start() if err != nil { return errorx.Decorate(err, "Couldn't start DHCP server") } @@ -27,7 +27,7 @@ func stopDHCPServer() error { return nil } - err := config.dhcpServer.Stop() + err := Context.dhcpServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop DHCP server") } diff --git a/home/dns.go b/home/dns.go index dbc48a73..d6fc8b8a 100644 --- a/home/dns.go +++ b/home/dns.go @@ -15,11 +15,6 @@ import ( "github.com/joomcode/errorx" ) -type dnsContext struct { - rdns *RDNS - whois *Whois -} - // Called by other modules when configuration is changed func onConfigModified() { _ = config.write() @@ -28,12 +23,12 @@ func onConfigModified() { // initDNSServer creates an instance of the dnsforward.Server // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats -func initDNSServer() { +func initDNSServer() error { baseDir := config.getDataDir() err := os.MkdirAll(baseDir, 0755) if err != nil { - log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) + return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err) } statsConf := stats.Config{ @@ -42,9 +37,9 @@ func initDNSServer() { ConfigModified: onConfigModified, HTTPRegister: httpRegister, } - config.stats, err = stats.New(statsConf) + Context.stats, err = stats.New(statsConf) if err != nil { - log.Fatal("Couldn't initialize statistics module") + return fmt.Errorf("Couldn't initialize statistics module") } conf := querylog.Config{ Enabled: config.DNS.QueryLogEnabled, @@ -54,7 +49,7 @@ func initDNSServer() { ConfigModified: onConfigModified, HTTPRegister: httpRegister, } - config.queryLog = querylog.New(conf) + Context.queryLog = querylog.New(conf) filterConf := config.DNS.DnsfilterConf bindhost := config.DNS.BindHost @@ -64,22 +59,28 @@ func initDNSServer() { filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) filterConf.ConfigModified = onConfigModified filterConf.HTTPRegister = httpRegister - config.dnsFilter = dnsfilter.New(&filterConf, nil) + Context.dnsFilter = dnsfilter.New(&filterConf, nil) - config.dnsServer = dnsforward.NewServer(config.dnsFilter, config.stats, config.queryLog) + Context.dnsServer = dnsforward.NewServer(Context.dnsFilter, Context.stats, Context.queryLog) + dnsConfig := generateServerConfig() + err = Context.dnsServer.Prepare(&dnsConfig) + if err != nil { + return fmt.Errorf("dnsServer.Prepare: %s", err) + } sessFilename := filepath.Join(baseDir, "sessions.db") config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) config.Users = nil - config.dnsctx.rdns = InitRDNS(&config.clients) - config.dnsctx.whois = initWhois(&config.clients) + Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) + Context.whois = initWhois(&Context.clients) initFiltering() + return nil } func isRunning() bool { - return config.dnsServer != nil && config.dnsServer.IsRunning() + return Context.dnsServer != nil && Context.dnsServer.IsRunning() } // nolint (gocyclo) @@ -145,14 +146,14 @@ func onDNSRequest(d *proxy.DNSContext) { ipAddr := net.ParseIP(ip) if !ipAddr.IsLoopback() { - config.dnsctx.rdns.Begin(ip) + Context.rdns.Begin(ip) } if isPublicIP(ipAddr) { - config.dnsctx.whois.Begin(ip) + Context.whois.Begin(ip) } } -func generateServerConfig() (dnsforward.ServerConfig, error) { +func generateServerConfig() dnsforward.ServerConfig { newconfig := dnsforward.ServerConfig{ UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, @@ -171,11 +172,11 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { newconfig.FilterHandler = applyAdditionalFiltering newconfig.GetUpstreamsByClient = getUpstreamsByClient - return newconfig, nil + return newconfig } func getUpstreamsByClient(clientAddr string) []string { - c, ok := config.clients.Find(clientAddr) + c, ok := Context.clients.Find(clientAddr) if !ok { return []string{} } @@ -192,7 +193,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri return } - c, ok := config.clients.Find(clientAddr) + c, ok := Context.clients.Find(clientAddr) if !ok { return } @@ -220,12 +221,7 @@ func startDNSServer() error { enableFilters(false) - newconfig, err := generateServerConfig() - if err != nil { - return errorx.Decorate(err, "Couldn't start forwarding DNS server") - } - - err = config.dnsServer.Start(&newconfig) + err := Context.dnsServer.Start() if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } @@ -233,14 +229,14 @@ func startDNSServer() error { startFiltering() const topClientsNumber = 100 // the number of clients to get - topClients := config.stats.GetTopClientsIP(topClientsNumber) + topClients := Context.stats.GetTopClientsIP(topClientsNumber) for _, ip := range topClients { ipAddr := net.ParseIP(ip) if !ipAddr.IsLoopback() { - config.dnsctx.rdns.Begin(ip) + Context.rdns.Begin(ip) } if isPublicIP(ipAddr) { - config.dnsctx.whois.Begin(ip) + Context.whois.Begin(ip) } } @@ -248,11 +244,8 @@ func startDNSServer() error { } func reconfigureDNSServer() error { - newconfig, err := generateServerConfig() - if err != nil { - return errorx.Decorate(err, "Couldn't start forwarding DNS server") - } - err = config.dnsServer.Reconfigure(&newconfig) + newconfig := generateServerConfig() + err := Context.dnsServer.Reconfigure(&newconfig) if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } @@ -261,26 +254,22 @@ func reconfigureDNSServer() error { } func stopDNSServer() error { - if !isRunning() { - return nil - } - - err := config.dnsServer.Stop() + err := Context.dnsServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop forwarding DNS server") } // DNS forward module must be closed BEFORE stats or queryLog because it depends on them - config.dnsServer.Close() + Context.dnsServer.Close() - config.dnsFilter.Close() - config.dnsFilter = nil + Context.dnsFilter.Close() + Context.dnsFilter = nil - config.stats.Close() - config.stats = nil + Context.stats.Close() + Context.stats = nil - config.queryLog.Close() - config.queryLog = nil + Context.queryLog.Close() + Context.queryLog = nil config.auth.Close() config.auth = nil diff --git a/home/dns_test.go b/home/dns_test.go deleted file mode 100644 index 52344279..00000000 --- a/home/dns_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package home - -import ( - "os" - "testing" -) - -func TestResolveRDNS(t *testing.T) { - _ = os.RemoveAll(config.getDataDir()) - defer func() { _ = os.RemoveAll(config.getDataDir()) }() - - config.DNS.BindHost = "1.1.1.1" - initDNSServer() - if r := config.dnsctx.rdns.resolve("1.1.1.1"); r != "one.one.one.one" { - t.Errorf("resolveRDNS(): %s", r) - } -} diff --git a/home/filter.go b/home/filter.go index 04a32d87..4359d9f4 100644 --- a/home/filter.go +++ b/home/filter.go @@ -514,5 +514,5 @@ func enableFilters(async bool) { } } - _ = config.dnsFilter.SetFilters(filters, async) + _ = Context.dnsFilter.SetFilters(filters, async) } diff --git a/home/helpers.go b/home/helpers.go index f4be0b38..2fdd0181 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -16,7 +16,6 @@ import ( "syscall" "time" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -118,7 +117,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res return } // enforce https? - if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil { + if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { // yes, and we want host from host:port host, _, err := net.SplitHostPort(r.Host) if err != nil { @@ -273,14 +272,8 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err return con, err } - bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" - } - resolverAddr := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) - r := upstream.NewResolver(resolverAddr, 30*time.Second) - addrs, e := r.LookupIPAddr(ctx, host) - log.Tracef("LookupIPAddr: %s: %v", host, addrs) + addrs, e := Context.dnsServer.Resolve(host) + log.Debug("dnsServer.Resolve: %s: %v", host, addrs) if e != nil { return nil, e } diff --git a/home/home.go b/home/home.go index 1ff0f128..66309c53 100644 --- a/home/home.go +++ b/home/home.go @@ -21,6 +21,10 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/querylog" + "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/log" "github.com/NYTimes/gziphandler" "github.com/gobuffalo/packr" @@ -40,6 +44,23 @@ var ( const versionCheckPeriod = time.Hour * 8 +// Global context +type homeContext struct { + clients clientsContainer // per-client-settings module + stats stats.Stats // statistics module + queryLog querylog.QueryLog // query log module + dnsServer *dnsforward.Server // DNS module + rdns *RDNS // rDNS module + whois *Whois // WHOIS module + dnsFilter *dnsfilter.Dnsfilter // DNS filtering module + dhcpServer *dhcpd.Server // DHCP module + httpServer *http.Server // HTTP module + httpsServer HTTPSServer // HTTPS module +} + +// Context - a global context object +var Context homeContext + // Main is the entry point func Main(version string, channel string) { // Init update-related global variables @@ -122,8 +143,8 @@ func run(args options) { config.DHCP.WorkDir = config.ourWorkingDir config.DHCP.HTTPRegister = httpRegister config.DHCP.ConfigModified = onConfigModified - config.dhcpServer = dhcpd.Create(config.DHCP) - config.clients.Init(config.Clients, config.dhcpServer) + Context.dhcpServer = dhcpd.Create(config.DHCP) + Context.clients.Init(config.Clients, Context.dhcpServer) config.Clients = nil if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && @@ -146,7 +167,10 @@ func run(args options) { log.Fatal(err) } - initDNSServer() + err = initDNSServer() + if err != nil { + log.Fatalf("%s", err) + } go func() { err = startDNSServer() if err != nil { @@ -178,21 +202,21 @@ func run(args options) { registerInstallHandlers() } - config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex) + Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex) // for https, we have a separate goroutine loop go httpServerLoop() // this loop is used as an ability to change listening host and/or port - for !config.httpsServer.shutdown { + for !Context.httpsServer.shutdown { printHTTPAddresses("http") // we need to have new instance, because after Shutdown() the Server is not usable address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) - config.httpServer = &http.Server{ + Context.httpServer = &http.Server{ Addr: address, } - err := config.httpServer.ListenAndServe() + err := Context.httpServer.ListenAndServe() if err != http.ErrServerClosed { cleanupAlways() log.Fatal(err) @@ -205,14 +229,14 @@ func run(args options) { } func httpServerLoop() { - for !config.httpsServer.shutdown { - config.httpsServer.cond.L.Lock() + for !Context.httpsServer.shutdown { + Context.httpsServer.cond.L.Lock() // this mechanism doesn't let us through until all conditions are met for config.TLS.Enabled == false || config.TLS.PortHTTPS == 0 || len(config.TLS.PrivateKeyData) == 0 || len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied - config.httpsServer.cond.Wait() + Context.httpsServer.cond.Wait() } address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS)) // validate current TLS config and update warnings (it could have been loaded from file) @@ -236,10 +260,10 @@ func httpServerLoop() { cleanupAlways() log.Fatal(err) } - config.httpsServer.cond.L.Unlock() + Context.httpsServer.cond.L.Unlock() // prepare HTTPS server - config.httpsServer.server = &http.Server{ + Context.httpsServer.server = &http.Server{ Addr: address, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -248,7 +272,7 @@ func httpServerLoop() { } printHTTPAddresses("https") - err = config.httpsServer.server.ListenAndServeTLS("", "") + err = Context.httpsServer.server.ListenAndServeTLS("", "") if err != http.ErrServerClosed { cleanupAlways() log.Fatal(err) @@ -326,11 +350,10 @@ func configureLogger(args options) { ls.LogFile = args.logFile } - level := log.INFO + // log.SetLevel(log.INFO) - default if ls.Verbose { - level = log.DEBUG + log.SetLevel(log.DEBUG) } - log.SetLevel(level) if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" { // When running as a Windows service, use eventlog by default if nothing else is configured @@ -378,11 +401,11 @@ func cleanup() { // Stop HTTP server, possibly waiting for all active connections to be closed func stopHTTPServer() { log.Info("Stopping HTTP server...") - config.httpsServer.shutdown = true - if config.httpsServer.server != nil { - config.httpsServer.server.Shutdown(context.TODO()) + Context.httpsServer.shutdown = true + if Context.httpsServer.server != nil { + Context.httpsServer.server.Shutdown(context.TODO()) } - config.httpServer.Shutdown(context.TODO()) + Context.httpServer.Shutdown(context.TODO()) log.Info("Stopped HTTP server") } diff --git a/home/home_test.go b/home/home_test.go new file mode 100644 index 00000000..8e2c0508 --- /dev/null +++ b/home/home_test.go @@ -0,0 +1,154 @@ +package home + +import ( + "context" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/stretchr/testify/assert" +) + +const yamlConf = `bind_host: 127.0.0.1 +bind_port: 3000 +users: [] +language: en +rlimit_nofile: 0 +web_session_ttl: 720 +dns: + bind_host: 127.0.0.1 + port: 5354 + statistics_interval: 90 + querylog_enabled: true + querylog_interval: 90 + querylog_memsize: 0 + protection_enabled: true + blocking_mode: null_ip + blocked_response_ttl: 0 + ratelimit: 100 + ratelimit_whitelist: [] + refuse_any: false + bootstrap_dns: + - 1.1.1.1:53 + all_servers: false + allowed_clients: [] + disallowed_clients: [] + blocked_hosts: [] + parental_block_host: family-block.dns.adguard.com + safebrowsing_block_host: standard-block.dns.adguard.com + cache_size: 0 + upstream_dns: + - https://1.1.1.1/dns-query + filtering_enabled: true + filters_update_interval: 168 + parental_sensitivity: 13 + parental_enabled: true + safesearch_enabled: false + safebrowsing_enabled: false + safebrowsing_cache_size: 1048576 + safesearch_cache_size: 1048576 + parental_cache_size: 1048576 + cache_time: 30 + rewrites: [] + blocked_services: [] +tls: + enabled: false + server_name: www.example.com + force_https: false + port_https: 443 + port_dns_over_tls: 853 + certificate_chain: "" + private_key: "" + certificate_path: "" + private_key_path: "" +filters: +- enabled: true + url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt + name: AdGuard Simplified Domain Names filter + id: 1 +- enabled: false + url: https://hosts-file.net/ad_servers.txt + name: hpHosts - Ad and Tracking servers only + id: 2 +- enabled: false + url: https://adaway.org/hosts.txt + name: adaway + id: 3 +user_rules: +- "" +dhcp: + enabled: false + interface_name: "" + gateway_ip: "" + subnet_mask: "" + range_start: "" + range_end: "" + lease_duration: 86400 + icmp_timeout_msec: 1000 +clients: [] +log_file: "" +verbose: false +schema_version: 5 +` + +// . Create a configuration file +// . Start AGH instance +// . Check Web server +// . Check DNS server +// . Wait until the filters are downloaded +// . Stop and cleanup +func TestHome(t *testing.T) { + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + fn := filepath.Join(dir, "AdGuardHome.yaml") + + assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0644) == nil) + fn, _ = filepath.Abs(fn) + + args := options{} + args.configFilename = fn + args.workDir = dir + go run(args) + + var err error + var resp *http.Response + h := http.Client{} + for i := 0; i != 5; i++ { + resp, err = h.Get("http://127.0.0.1:3000/") + if err == nil && resp.StatusCode != 404 { + break + } + time.Sleep(1 * time.Second) + } + assert.Truef(t, err == nil, "%s", err) + assert.Equal(t, 200, resp.StatusCode) + + resp, err = h.Get("http://127.0.0.1:3000/control/status") + assert.Truef(t, err == nil, "%s", err) + assert.Equal(t, 200, resp.StatusCode) + + r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second) + addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com") + assert.Truef(t, err == nil, "%s", err) + haveIP := len(addrs) != 0 + assert.True(t, haveIP) + + for i := 1; ; i++ { + st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt")) + if err == nil && st.Size() != 0 { + break + } + if i == 5 { + assert.True(t, false) + break + } + time.Sleep(1 * time.Second) + } + + cleanup() + cleanupAlways() +} diff --git a/home/rdns.go b/home/rdns.go index e395fdb3..297e1925 100644 --- a/home/rdns.go +++ b/home/rdns.go @@ -2,25 +2,20 @@ package home import ( "encoding/binary" - "fmt" "strings" "time" - "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -const ( - rdnsTimeout = 3 * time.Second // max time to wait for rDNS response -) - // RDNS - module context type RDNS struct { + dnsServer *dnsforward.Server clients *clientsContainer - ipChannel chan string // pass data from DNS request handling thread to rDNS thread - upstream upstream.Upstream // Upstream object for our own DNS server + ipChannel chan string // pass data from DNS request handling thread to rDNS thread // Contains IP addresses of clients to be resolved by rDNS // If IP address is resolved, it stays here while it's inside Clients. @@ -30,25 +25,10 @@ type RDNS struct { } // InitRDNS - create module context -func InitRDNS(clients *clientsContainer) *RDNS { +func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { r := RDNS{} + r.dnsServer = dnsServer r.clients = clients - var err error - - bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" - } - resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) - - opts := upstream.Options{ - Timeout: rdnsTimeout, - } - r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) - if err != nil { - log.Error("upstream.AddressToUpstream: %s", err) - return nil - } cconf := cache.Config{} cconf.EnableLRU = true @@ -109,7 +89,7 @@ func (r *RDNS) resolve(ip string) string { return "" } - resp, err := r.upstream.Exchange(&req) + resp, err := r.dnsServer.Exchange(&req) if err != nil { log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) return "" @@ -144,6 +124,6 @@ func (r *RDNS) workerLoop() { continue } - _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) + _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) } } diff --git a/home/rdns_test.go b/home/rdns_test.go new file mode 100644 index 00000000..84c996d1 --- /dev/null +++ b/home/rdns_test.go @@ -0,0 +1,21 @@ +package home + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/stretchr/testify/assert" +) + +func TestResolveRDNS(t *testing.T) { + dns := &dnsforward.Server{} + conf := &dnsforward.ServerConfig{} + conf.UpstreamDNS = []string{"8.8.8.8"} + err := dns.Prepare(conf) + assert.True(t, err == nil, "%s", err) + + clients := &clientsContainer{} + rdns := InitRDNS(dns, clients) + r := rdns.resolve("1.1.1.1") + assert.True(t, r == "one.one.one.one", "%s", r) +} diff --git a/stats/stats.go b/stats/stats.go index 68e53b3c..f4c05b94 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -25,6 +25,8 @@ type Config struct { // Register an HTTP handler HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) + + limit uint32 // maximum time we need to keep data for (in hours) } // New - create object diff --git a/stats/stats_http.go b/stats/stats_http.go index 4d860fee..ab49a6ae 100644 --- a/stats/stats_http.go +++ b/stats/stats_http.go @@ -21,13 +21,8 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string, // Return data func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { - units := Hours - if s.limit/24 > 7 { - units = Days - } - start := time.Now() - d := s.getData(units) + d := s.getData() log.Debug("Stats: prepared data in %v", time.Since(start)) if d == nil { @@ -52,7 +47,7 @@ type config struct { // Get configuration func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { resp := config{} - resp.IntervalDays = s.limit / 24 + resp.IntervalDays = s.conf.limit / 24 data, err := json.Marshal(resp) if err != nil { diff --git a/stats/stats_test.go b/stats/stats_test.go index 0c7c4bb5..7b2b1a99 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -45,7 +45,7 @@ func TestStats(t *testing.T) { e.Time = 123456 s.Update(e) - d := s.getData(Hours) + d := s.getData() a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} assert.True(t, UIntArrayEquals(d["dns_queries"].([]uint64), a)) @@ -116,7 +116,7 @@ func TestLargeNumbers(t *testing.T) { } } - d := s.getData(Hours) + d := s.getData() assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n)) s.Close() diff --git a/stats/stats_unit.go b/stats/stats_unit.go index 5b1661e6..b52b31f0 100644 --- a/stats/stats_unit.go +++ b/stats/stats_unit.go @@ -21,10 +21,8 @@ const ( // statsCtx - global context type statsCtx struct { - limit uint32 // maximum time we need to keep data for (in hours) - db *bolt.DB - - conf Config + db *bolt.DB + conf *Config unit *unit // the current unit unitLock sync.Mutex // protect 'unit' @@ -67,8 +65,9 @@ func createObject(conf Config) (*statsCtx, error) { if !checkInterval(conf.LimitDays) { conf.LimitDays = 1 } - s.limit = conf.LimitDays * 24 - s.conf = conf + s.conf = &Config{} + *s.conf = conf + s.conf.limit = conf.LimitDays * 24 if conf.UnitID == nil { s.conf.UnitID = newUnitID } @@ -82,7 +81,7 @@ func createObject(conf Config) (*statsCtx, error) { var udb *unitDB if tx != nil { log.Tracef("Deleting old units...") - firstID := id - s.limit - 1 + firstID := id - s.conf.limit - 1 unitDel := 0 forEachBkt := func(name []byte, b *bolt.Bucket) error { id := uint32(btoi(name)) @@ -243,7 +242,7 @@ func (s *statsCtx) periodicFlush() { continue } ok1 := s.flushUnitToDB(tx, u.id, udb) - ok2 := s.deleteUnit(tx, id-s.limit) + ok2 := s.deleteUnit(tx, id-s.conf.limit) if ok1 || ok2 { s.commitTxn(tx) } else { @@ -383,12 +382,14 @@ func convertTopArray(a []countPair) []map[string]uint64 { } func (s *statsCtx) setLimit(limitDays int) { - s.limit = uint32(limitDays) * 24 + conf := *s.conf + conf.limit = uint32(limitDays) * 24 + s.conf = &conf log.Debug("Stats: set limit: %d", limitDays) } func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) { - dc.Interval = s.limit / 24 + dc.Interval = s.conf.limit / 24 } func (s *statsCtx) Close() { @@ -466,7 +467,7 @@ func (s *statsCtx) Update(e Entry) { s.unitLock.Unlock() } -func (s *statsCtx) loadUnits() ([]*unitDB, uint32) { +func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { tx := s.beginTxn(false) if tx == nil { return nil, 0 @@ -478,7 +479,7 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) { s.unitLock.Unlock() units := []*unitDB{} //per-hour units - firstID := curID - s.limit + 1 + firstID := curID - limit + 1 for i := firstID; i != curID; i++ { u := s.loadUnitFromDB(tx, i) if u == nil { @@ -492,8 +493,8 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) { units = append(units, curUnit) - if len(units) != int(s.limit) { - log.Fatalf("len(units) != s.limit: %d %d", len(units), s.limit) + if len(units) != int(limit) { + log.Fatalf("len(units) != limit: %d %d", len(units), limit) } return units, firstID @@ -527,10 +528,16 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) { These values are just the sum of data for all units. */ // nolint (gocyclo) -func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} { - d := map[string]interface{}{} +func (s *statsCtx) getData() map[string]interface{} { + limit := s.conf.limit - units, firstID := s.loadUnits() + d := map[string]interface{}{} + timeUnit := Hours + if limit/24 > 7 { + timeUnit = Days + } + + units, firstID := s.loadUnits(limit) if units == nil { return nil } @@ -561,8 +568,8 @@ func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} { if id <= nextDayID { a = append(a, sum) } - if len(a) != int(s.limit/24) { - log.Fatalf("len(a) != s.limit: %d %d", len(a), s.limit) + if len(a) != int(limit/24) { + log.Fatalf("len(a) != limit: %d %d", len(a), limit) } } d["dns_queries"] = a @@ -705,8 +712,8 @@ func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} { return d } -func (s *statsCtx) GetTopClientsIP(limit uint) []string { - units, _ := s.loadUnits() +func (s *statsCtx) GetTopClientsIP(maxCount uint) []string { + units, _ := s.loadUnits(s.conf.limit) if units == nil { return nil } @@ -718,7 +725,7 @@ func (s *statsCtx) GetTopClientsIP(limit uint) []string { m[it.Name] += it.Count } } - a := convertMapToArray(m, int(limit)) + a := convertMapToArray(m, int(maxCount)) d := []string{} for _, it := range a { d = append(d, it.Name)