mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 04:55:33 +03:00
Merge: * use upstream servers directly for the internal DNS resolver
Close #1212 * Server.Start(config *ServerConfig) -> Start() + Server.Prepare(config *ServerConfig) + Server.Resolve(host string) + Server.Exchange() * rDNS: use internal DNS resolver - clients: fix race in WriteDiskConfig() - fix race: move 'clients' object from 'configuration' to 'HomeContext' Go race detector didn't like our 'clients' object in 'configuration'. + add AGH startup test . Create a configuration file . Start AGH instance . Check Web server . Check DNS server . Wait until the filters are downloaded . Stop and cleanup * move module objects from config.* to Context.* * don't call log.SetLevel() if not necessary This helps to avoid Go race detector's warning * ci.sh: 'make' and then run tests Squashed commit of the following: commit 86500c7f749307f37af4cc8c2a1066f679d0cfad Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 18:08:53 2019 +0300 minor commit 6e6abb9dca3cd250c458bec23aa30d2250a9eb40 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 18:08:31 2019 +0300 * ci.sh: 'make' and then run tests commit 114192eefea6800e565ba9ab238202c006516c27 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 17:50:04 2019 +0300 fix commit d426deea7f02cdfd4c7217a38c59e51251956a0f Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 17:46:33 2019 +0300 tests commit 7b350edf03027895b4e43dee908d0155a9b0ac9b Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:56:12 2019 +0300 fix test commit 2f5f116873bbbfdd4bb7f82a596f9e1f5c2bcfd8 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:48:56 2019 +0300 fix tests commit 3fbdc77f9c34726e2295185279444983652d559e Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:45:00 2019 +0300 linter commit 9da0b6965a2b6863bcd552fa83a4de2866600bb8 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:33:23 2019 +0300 * config.dnsctx.whois -> Context.whois commit c71ebdbdf6efd88c877b2f243c69d3bc00a997d7 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:31:08 2019 +0300 * don't call log.SetLevel() if not necessary This helps to avoid Go race detector's warning commit 0f250220133cefdcb0843a50000cb932802b8324 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 15:28:19 2019 +0300 * rdns: refactor commit c460d8c9414940dac852e390b6c1b4d4fb38dff9 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 14:08:08 2019 +0300 Revert: * stats: serialize access to 'limit' Use 'conf *Config' and update it atomically, as in querylog module. (Note: Race detector still doesn't like it) commit 488bcb884971276de0d5629384b29e22c59ee7e6 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 13:50:23 2019 +0300 * config.dnsFilter -> Context.dnsFilter commit 86c0a6827a450414b50acec7ebfc5220d13b81e4 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 13:45:05 2019 +0300 * config.dnsServer -> Context.dnsServer commit ee35ef095ccaabc89e3de0ef52c9b5ed56b36873 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 13:42:10 2019 +0300 * config.dhcpServer -> Context.dhcpServer commit 1537001cd211099d5fad01696c0b806ae5d257b1 Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 13:39:45 2019 +0300 * config.queryLog -> Context.queryLog commit e5955fe4ff1ef6f41763461b37b502ea25a3d04c Author: Simon Zolin <s.zolin@adguard.com> Date: Tue Dec 10 13:03:18 2019 +0300 * config.httpsServer -> Context.httpsServer commit 6153c10a9ac173e159d1f05e0db1512579b9203c Author: Simon Zolin <s.zolin@adguard.com> Date: Mon Dec 9 20:12:24 2019 +0300 * config.httpServer -> Context.httpServer commit abd021fb94039015cd45c97614e8b78d4694f956 Author: Simon Zolin <s.zolin@adguard.com> Date: Mon Dec 9 20:08:05 2019 +0300 * stats: serialize access to 'limit' commit 38c2decfd87c712100edcabe62a6d4518719cb53 Author: Simon Zolin <s.zolin@adguard.com> Date: Mon Dec 9 19:57:04 2019 +0300 * config.stats -> Context.stats commit 6caf8965ad44db9dce9a7a5103aa8fa305ad9a06 Author: Simon Zolin <s.zolin@adguard.com> Date: Mon Dec 9 19:45:23 2019 +0300 fix Restart() ... and 6 more commits
This commit is contained in:
parent
fe357d04f7
commit
0a66913b4d
23 changed files with 439 additions and 251 deletions
6
ci.sh
6
ci.sh
|
@ -16,14 +16,14 @@ golangci-lint --version
|
||||||
# Run linter
|
# Run linter
|
||||||
golangci-lint run
|
golangci-lint run
|
||||||
|
|
||||||
# Run tests
|
|
||||||
go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./...
|
|
||||||
|
|
||||||
# Make
|
# Make
|
||||||
make clean
|
make clean
|
||||||
make build/static/index.html
|
make build/static/index.html
|
||||||
make
|
make
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
go test -race -v -bench=. -coverprofile=coverage.txt -covermode=atomic ./...
|
||||||
|
|
||||||
# if [[ -z "$(git status --porcelain)" ]]; then
|
# if [[ -z "$(git status --porcelain)" ]]; then
|
||||||
# # Working directory clean
|
# # Working directory clean
|
||||||
# echo "Git status is clean"
|
# echo "Git status is clean"
|
||||||
|
|
|
@ -2,7 +2,6 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -51,7 +50,12 @@ type Server struct {
|
||||||
stats stats.Stats
|
stats stats.Stats
|
||||||
access *accessCtx
|
access *accessCtx
|
||||||
|
|
||||||
|
// DNS proxy instance for internal usage
|
||||||
|
// We don't Start() it and so no listen port is required.
|
||||||
|
internalProxy *proxy.Proxy
|
||||||
|
|
||||||
webRegistered bool
|
webRegistered bool
|
||||||
|
isRunning bool
|
||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
conf ServerConfig
|
conf ServerConfig
|
||||||
|
@ -78,6 +82,7 @@ func (s *Server) Close() {
|
||||||
s.dnsFilter = nil
|
s.dnsFilter = nil
|
||||||
s.stats = nil
|
s.stats = nil
|
||||||
s.queryLog = nil
|
s.queryLog = nil
|
||||||
|
s.dnsProxy = nil
|
||||||
s.Unlock()
|
s.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,28 +170,54 @@ var defaultValues = ServerConfig{
|
||||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve - get IP addresses by host name from an upstream server.
|
||||||
|
// No request/response filtering is performed.
|
||||||
|
// Query log and Stats are not updated.
|
||||||
|
// This method may be called before Start().
|
||||||
|
func (s *Server) Resolve(host string) ([]net.IPAddr, error) {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
return s.internalProxy.LookupIPAddr(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange - send DNS request to an upstream server and receive response
|
||||||
|
// No request/response filtering is performed.
|
||||||
|
// Query log and Stats are not updated.
|
||||||
|
// This method may be called before Start().
|
||||||
|
func (s *Server) Exchange(req *dns.Msg) (*dns.Msg, error) {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
|
||||||
|
ctx := &proxy.DNSContext{
|
||||||
|
Proto: "udp",
|
||||||
|
Req: req,
|
||||||
|
StartTime: time.Now(),
|
||||||
|
}
|
||||||
|
err := s.internalProxy.Resolve(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ctx.Res, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the DNS server
|
// Start starts the DNS server
|
||||||
func (s *Server) Start(config *ServerConfig) error {
|
func (s *Server) Start() error {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
return s.startInternal(config)
|
return s.startInternal()
|
||||||
}
|
}
|
||||||
|
|
||||||
// startInternal starts without locking
|
// startInternal starts without locking
|
||||||
func (s *Server) startInternal(config *ServerConfig) error {
|
func (s *Server) startInternal() error {
|
||||||
err := s.prepare(config)
|
err := s.dnsProxy.Start()
|
||||||
if err != nil {
|
if err == nil {
|
||||||
return err
|
s.isRunning = true
|
||||||
}
|
}
|
||||||
return s.dnsProxy.Start()
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the object
|
// Prepare the object
|
||||||
func (s *Server) prepare(config *ServerConfig) error {
|
func (s *Server) Prepare(config *ServerConfig) error {
|
||||||
if s.dnsProxy != nil {
|
|
||||||
return errors.New("DNS server is already started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config != nil {
|
if config != nil {
|
||||||
s.conf = *config
|
s.conf = *config
|
||||||
}
|
}
|
||||||
|
@ -234,6 +265,14 @@ func (s *Server) prepare(config *ServerConfig) error {
|
||||||
EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet,
|
EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
intlProxyConfig := proxy.Config{
|
||||||
|
CacheEnabled: true,
|
||||||
|
CacheSizeBytes: 4096,
|
||||||
|
Upstreams: s.conf.Upstreams,
|
||||||
|
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
||||||
|
}
|
||||||
|
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
|
||||||
|
|
||||||
s.access = &accessCtx{}
|
s.access = &accessCtx{}
|
||||||
err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts)
|
err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -277,24 +316,20 @@ func (s *Server) Stop() error {
|
||||||
func (s *Server) stopInternal() error {
|
func (s *Server) stopInternal() error {
|
||||||
if s.dnsProxy != nil {
|
if s.dnsProxy != nil {
|
||||||
err := s.dnsProxy.Stop()
|
err := s.dnsProxy.Stop()
|
||||||
s.dnsProxy = nil
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "could not stop the DNS server properly")
|
return errorx.Decorate(err, "could not stop the DNS server properly")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.isRunning = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsRunning returns true if the DNS server is running
|
// IsRunning returns true if the DNS server is running
|
||||||
func (s *Server) IsRunning() bool {
|
func (s *Server) IsRunning() bool {
|
||||||
s.RLock()
|
s.RLock()
|
||||||
isRunning := true
|
defer s.RUnlock()
|
||||||
if s.dnsProxy == nil {
|
return s.isRunning
|
||||||
isRunning = false
|
|
||||||
}
|
|
||||||
s.RUnlock()
|
|
||||||
return isRunning
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restart - restart server
|
// Restart - restart server
|
||||||
|
@ -306,7 +341,7 @@ func (s *Server) Restart() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "could not reconfigure the server")
|
return errorx.Decorate(err, "could not reconfigure the server")
|
||||||
}
|
}
|
||||||
err = s.startInternal(nil)
|
err = s.startInternal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "could not reconfigure the server")
|
return errorx.Decorate(err, "could not reconfigure the server")
|
||||||
}
|
}
|
||||||
|
@ -330,7 +365,12 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.startInternal(config)
|
err = s.Prepare(config)
|
||||||
|
if err != nil {
|
||||||
|
return errorx.Decorate(err, "could not reconfigure the server")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.startInternal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "could not reconfigure the server")
|
return errorx.Decorate(err, "could not reconfigure the server")
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ const (
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func TestServer(t *testing.T) {
|
||||||
func TestServerWithProtectionDisabled(t *testing.T) {
|
func TestServerWithProtectionDisabled(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
s.conf.ProtectionEnabled = false
|
s.conf.ProtectionEnabled = false
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -94,8 +94,9 @@ func TestDotServer(t *testing.T) {
|
||||||
PrivateKeyData: keyPem,
|
PrivateKeyData: keyPem,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_ = s.Prepare(nil)
|
||||||
// Starting the server
|
// Starting the server
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -127,7 +128,7 @@ func TestDotServer(t *testing.T) {
|
||||||
|
|
||||||
func TestServerRace(t *testing.T) {
|
func TestServerRace(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -150,7 +151,7 @@ func TestServerRace(t *testing.T) {
|
||||||
|
|
||||||
func TestSafeSearch(t *testing.T) {
|
func TestSafeSearch(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -191,7 +192,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
|
|
||||||
func TestInvalidRequest(t *testing.T) {
|
func TestInvalidRequest(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -217,7 +218,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
|
|
||||||
func TestBlockedRequest(t *testing.T) {
|
func TestBlockedRequest(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -319,7 +320,7 @@ func (u *testUpstream) Address() string {
|
||||||
func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
err := s.prepare(nil)
|
err := s.Prepare(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -386,7 +387,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
func TestNullBlockedRequest(t *testing.T) {
|
func TestNullBlockedRequest(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
s.conf.FilteringConfig.BlockingMode = "null_ip"
|
s.conf.FilteringConfig.BlockingMode = "null_ip"
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -425,7 +426,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||||
|
|
||||||
func TestBlockedByHosts(t *testing.T) {
|
func TestBlockedByHosts(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -464,7 +465,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||||
|
|
||||||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
err := s.Start(nil)
|
err := s.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to start server: %s", err)
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -530,6 +531,8 @@ func createTestServer(t *testing.T) *Server {
|
||||||
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
|
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||||
s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
|
s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
|
||||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||||
|
err := s.Prepare(nil)
|
||||||
|
assert.True(t, err == nil)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.22.0
|
github.com/AdguardTeam/dnsproxy v0.23.0
|
||||||
github.com/AdguardTeam/golibs v0.3.0
|
github.com/AdguardTeam/golibs v0.3.0
|
||||||
github.com/AdguardTeam/urlfilter v0.7.0
|
github.com/AdguardTeam/urlfilter v0.7.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -1,5 +1,5 @@
|
||||||
github.com/AdguardTeam/dnsproxy v0.22.0 h1:8mpPu+KN0puFTHNhGy7XQ13fe3+3DGFaiwnqhNMWl+M=
|
github.com/AdguardTeam/dnsproxy v0.23.0 h1:GrOUapcWjf19MF8NznZUbcYujBbl7QXapBWTFKqkJQg=
|
||||||
github.com/AdguardTeam/dnsproxy v0.22.0/go.mod h1:2qy8rpdfBzKgMPxkHmPdaNK4XZJ322v4KtVGI8s8Bn0=
|
github.com/AdguardTeam/dnsproxy v0.23.0/go.mod h1:2qy8rpdfBzKgMPxkHmPdaNK4XZJ322v4KtVGI8s8Bn0=
|
||||||
github.com/AdguardTeam/golibs v0.2.4 h1:GUssokegKxKF13K67Pgl0ZGwqHjNN6X7sep5ik6ORdY=
|
github.com/AdguardTeam/golibs v0.2.4 h1:GUssokegKxKF13K67Pgl0ZGwqHjNN6X7sep5ik6ORdY=
|
||||||
github.com/AdguardTeam/golibs v0.2.4/go.mod h1:R3M+mAg3nWG4X4Hsag5eef/TckHFH12ZYhK7AzJc8+U=
|
github.com/AdguardTeam/golibs v0.2.4/go.mod h1:R3M+mAg3nWG4X4Hsag5eef/TckHFH12ZYhK7AzJc8+U=
|
||||||
github.com/AdguardTeam/golibs v0.3.0 h1:1zO8ulGEOdXDDM++Ap4sYfTsT/Z4tZBZtiWSA4ykcOU=
|
github.com/AdguardTeam/golibs v0.3.0 h1:1zO8ulGEOdXDDM++Ap4sYfTsT/Z4tZBZtiWSA4ykcOU=
|
||||||
|
|
|
@ -128,24 +128,30 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||||
|
|
||||||
// WriteDiskConfig - write configuration
|
// WriteDiskConfig - write configuration
|
||||||
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||||
clientsList := clients.GetList()
|
clients.lock.Lock()
|
||||||
for _, cli := range clientsList {
|
for _, cli := range clients.list {
|
||||||
cy := clientObject{
|
cy := clientObject{
|
||||||
Name: cli.Name,
|
Name: cli.Name,
|
||||||
IDs: cli.IDs,
|
UseGlobalSettings: !cli.UseOwnSettings,
|
||||||
UseGlobalSettings: !cli.UseOwnSettings,
|
FilteringEnabled: cli.FilteringEnabled,
|
||||||
FilteringEnabled: cli.FilteringEnabled,
|
ParentalEnabled: cli.ParentalEnabled,
|
||||||
ParentalEnabled: cli.ParentalEnabled,
|
SafeSearchEnabled: cli.SafeSearchEnabled,
|
||||||
SafeSearchEnabled: cli.SafeSearchEnabled,
|
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
||||||
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
|
||||||
|
|
||||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||||
BlockedServices: cli.BlockedServices,
|
|
||||||
|
|
||||||
Upstreams: cli.Upstreams,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cy.IDs = make([]string, len(cli.IDs))
|
||||||
|
copy(cy.IDs, cli.IDs)
|
||||||
|
|
||||||
|
cy.BlockedServices = make([]string, len(cli.BlockedServices))
|
||||||
|
copy(cy.BlockedServices, cli.BlockedServices)
|
||||||
|
|
||||||
|
cy.Upstreams = make([]string, len(cli.Upstreams))
|
||||||
|
copy(cy.Upstreams, cli.Upstreams)
|
||||||
|
|
||||||
*objects = append(*objects, cy)
|
*objects = append(*objects, cy)
|
||||||
}
|
}
|
||||||
|
clients.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (clients *clientsContainer) periodicUpdate() {
|
func (clients *clientsContainer) periodicUpdate() {
|
||||||
|
@ -157,11 +163,6 @@ func (clients *clientsContainer) periodicUpdate() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetList returns the pointer to clients list
|
|
||||||
func (clients *clientsContainer) GetList() map[string]*Client {
|
|
||||||
return clients.list
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exists checks if client with this IP already exists
|
// Exists checks if client with this IP already exists
|
||||||
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
|
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
|
|
|
@ -29,6 +29,7 @@ type logSettings struct {
|
||||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HTTPSServer - HTTPS Server
|
||||||
type HTTPSServer struct {
|
type HTTPSServer struct {
|
||||||
server *http.Server
|
server *http.Server
|
||||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||||
|
@ -51,25 +52,15 @@ type configuration struct {
|
||||||
runningAsService bool
|
runningAsService bool
|
||||||
disableUpdate bool // If set, don't check for updates
|
disableUpdate bool // If set, don't check for updates
|
||||||
appSignalChannel chan os.Signal
|
appSignalChannel chan os.Signal
|
||||||
clients clientsContainer // per-client-settings module
|
|
||||||
controlLock sync.Mutex
|
controlLock sync.Mutex
|
||||||
transport *http.Transport
|
transport *http.Transport
|
||||||
client *http.Client
|
client *http.Client
|
||||||
stats stats.Stats // statistics module
|
auth *Auth // HTTP authentication module
|
||||||
queryLog querylog.QueryLog // query log module
|
|
||||||
auth *Auth // HTTP authentication module
|
|
||||||
|
|
||||||
// cached version.json to avoid hammering github.io for each page reload
|
// cached version.json to avoid hammering github.io for each page reload
|
||||||
versionCheckJSON []byte
|
versionCheckJSON []byte
|
||||||
versionCheckLastTime time.Time
|
versionCheckLastTime time.Time
|
||||||
|
|
||||||
dnsctx dnsContext
|
|
||||||
dnsFilter *dnsfilter.Dnsfilter
|
|
||||||
dnsServer *dnsforward.Server
|
|
||||||
dhcpServer *dhcpd.Server
|
|
||||||
httpServer *http.Server
|
|
||||||
httpsServer HTTPSServer
|
|
||||||
|
|
||||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||||
Users []User `yaml:"users"` // Users that can access HTTP server
|
Users []User `yaml:"users"` // Users that can access HTTP server
|
||||||
|
@ -296,41 +287,41 @@ func (c *configuration) write() error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
config.clients.WriteDiskConfig(&config.Clients)
|
Context.clients.WriteDiskConfig(&config.Clients)
|
||||||
|
|
||||||
if config.auth != nil {
|
if config.auth != nil {
|
||||||
config.Users = config.auth.GetUsers()
|
config.Users = config.auth.GetUsers()
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.stats != nil {
|
if Context.stats != nil {
|
||||||
sdc := stats.DiskConfig{}
|
sdc := stats.DiskConfig{}
|
||||||
config.stats.WriteDiskConfig(&sdc)
|
Context.stats.WriteDiskConfig(&sdc)
|
||||||
config.DNS.StatsInterval = sdc.Interval
|
config.DNS.StatsInterval = sdc.Interval
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.queryLog != nil {
|
if Context.queryLog != nil {
|
||||||
dc := querylog.DiskConfig{}
|
dc := querylog.DiskConfig{}
|
||||||
config.queryLog.WriteDiskConfig(&dc)
|
Context.queryLog.WriteDiskConfig(&dc)
|
||||||
config.DNS.QueryLogEnabled = dc.Enabled
|
config.DNS.QueryLogEnabled = dc.Enabled
|
||||||
config.DNS.QueryLogInterval = dc.Interval
|
config.DNS.QueryLogInterval = dc.Interval
|
||||||
config.DNS.QueryLogMemSize = dc.MemSize
|
config.DNS.QueryLogMemSize = dc.MemSize
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.dnsFilter != nil {
|
if Context.dnsFilter != nil {
|
||||||
c := dnsfilter.Config{}
|
c := dnsfilter.Config{}
|
||||||
config.dnsFilter.WriteDiskConfig(&c)
|
Context.dnsFilter.WriteDiskConfig(&c)
|
||||||
config.DNS.DnsfilterConf = c
|
config.DNS.DnsfilterConf = c
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.dnsServer != nil {
|
if Context.dnsServer != nil {
|
||||||
c := dnsforward.FilteringConfig{}
|
c := dnsforward.FilteringConfig{}
|
||||||
config.dnsServer.WriteDiskConfig(&c)
|
Context.dnsServer.WriteDiskConfig(&c)
|
||||||
config.DNS.FilteringConfig = c
|
config.DNS.FilteringConfig = c
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.dhcpServer != nil {
|
if Context.dhcpServer != nil {
|
||||||
c := dhcpd.ServerConfig{}
|
c := dhcpd.ServerConfig{}
|
||||||
config.dhcpServer.WriteDiskConfig(&c)
|
Context.dhcpServer.WriteDiskConfig(&c)
|
||||||
config.DHCP = c
|
config.DHCP = c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -93,8 +93,8 @@ func getDNSAddresses() []string {
|
||||||
|
|
||||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
c := dnsforward.FilteringConfig{}
|
c := dnsforward.FilteringConfig{}
|
||||||
if config.dnsServer != nil {
|
if Context.dnsServer != nil {
|
||||||
config.dnsServer.WriteDiskConfig(&c)
|
Context.dnsServer.WriteDiskConfig(&c)
|
||||||
}
|
}
|
||||||
data := map[string]interface{}{
|
data := map[string]interface{}{
|
||||||
"dns_addresses": getDNSAddresses(),
|
"dns_addresses": getDNSAddresses(),
|
||||||
|
@ -154,7 +154,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.dnsServer.ServeHTTP(w, r)
|
Context.dnsServer.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------
|
// ------------------------
|
||||||
|
|
|
@ -235,13 +235,19 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
config.DNS.BindHost = newSettings.DNS.IP
|
config.DNS.BindHost = newSettings.DNS.IP
|
||||||
config.DNS.Port = newSettings.DNS.Port
|
config.DNS.Port = newSettings.DNS.Port
|
||||||
|
|
||||||
initDNSServer()
|
err = initDNSServer()
|
||||||
|
var err2 error
|
||||||
err = startDNSServer()
|
if err == nil {
|
||||||
if err != nil {
|
err2 = startDNSServer()
|
||||||
|
}
|
||||||
|
if err != nil || err2 != nil {
|
||||||
config.firstRun = true
|
config.firstRun = true
|
||||||
copyInstallSettings(&config, &curConfig)
|
copyInstallSettings(&config, &curConfig)
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err)
|
if err != nil {
|
||||||
|
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
|
||||||
|
} else {
|
||||||
|
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,7 +267,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||||
if restartHTTP {
|
if restartHTTP {
|
||||||
go func() {
|
go func() {
|
||||||
_ = config.httpServer.Shutdown(context.TODO())
|
_ = Context.httpServer.Shutdown(context.TODO())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||||
// check if port is available
|
// check if port is available
|
||||||
// BUT: if we are already using this port, no need
|
// BUT: if we are already using this port, no need
|
||||||
alreadyRunning := false
|
alreadyRunning := false
|
||||||
if config.httpsServer.server != nil {
|
if Context.httpsServer.server != nil {
|
||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
|
@ -110,7 +110,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
// check if port is available
|
// check if port is available
|
||||||
// BUT: if we are already using this port, no need
|
// BUT: if we are already using this port, no need
|
||||||
alreadyRunning := false
|
alreadyRunning := false
|
||||||
if config.httpsServer.server != nil {
|
if Context.httpsServer.server != nil {
|
||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
|
@ -145,12 +145,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
if restartHTTPS {
|
if restartHTTPS {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||||
config.httpsServer.cond.L.Lock()
|
Context.httpsServer.cond.L.Lock()
|
||||||
config.httpsServer.cond.Broadcast()
|
Context.httpsServer.cond.Broadcast()
|
||||||
if config.httpsServer.server != nil {
|
if Context.httpsServer.server != nil {
|
||||||
config.httpsServer.server.Shutdown(context.TODO())
|
Context.httpsServer.server.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
config.httpsServer.cond.L.Unlock()
|
Context.httpsServer.cond.L.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,12 +10,12 @@ func startDHCPServer() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := config.dhcpServer.Init(config.DHCP)
|
err := Context.dhcpServer.Init(config.DHCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't init DHCP server")
|
return errorx.Decorate(err, "Couldn't init DHCP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = config.dhcpServer.Start()
|
err = Context.dhcpServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start DHCP server")
|
return errorx.Decorate(err, "Couldn't start DHCP server")
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func stopDHCPServer() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := config.dhcpServer.Stop()
|
err := Context.dhcpServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
||||||
}
|
}
|
||||||
|
|
83
home/dns.go
83
home/dns.go
|
@ -15,11 +15,6 @@ import (
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dnsContext struct {
|
|
||||||
rdns *RDNS
|
|
||||||
whois *Whois
|
|
||||||
}
|
|
||||||
|
|
||||||
// Called by other modules when configuration is changed
|
// Called by other modules when configuration is changed
|
||||||
func onConfigModified() {
|
func onConfigModified() {
|
||||||
_ = config.write()
|
_ = config.write()
|
||||||
|
@ -28,12 +23,12 @@ func onConfigModified() {
|
||||||
// initDNSServer creates an instance of the dnsforward.Server
|
// initDNSServer creates an instance of the dnsforward.Server
|
||||||
// Please note that we must do it even if we don't start it
|
// Please note that we must do it even if we don't start it
|
||||||
// so that we had access to the query log and the stats
|
// so that we had access to the query log and the stats
|
||||||
func initDNSServer() {
|
func initDNSServer() error {
|
||||||
baseDir := config.getDataDir()
|
baseDir := config.getDataDir()
|
||||||
|
|
||||||
err := os.MkdirAll(baseDir, 0755)
|
err := os.MkdirAll(baseDir, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statsConf := stats.Config{
|
statsConf := stats.Config{
|
||||||
|
@ -42,9 +37,9 @@ func initDNSServer() {
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
}
|
}
|
||||||
config.stats, err = stats.New(statsConf)
|
Context.stats, err = stats.New(statsConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Couldn't initialize statistics module")
|
return fmt.Errorf("Couldn't initialize statistics module")
|
||||||
}
|
}
|
||||||
conf := querylog.Config{
|
conf := querylog.Config{
|
||||||
Enabled: config.DNS.QueryLogEnabled,
|
Enabled: config.DNS.QueryLogEnabled,
|
||||||
|
@ -54,7 +49,7 @@ func initDNSServer() {
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
}
|
}
|
||||||
config.queryLog = querylog.New(conf)
|
Context.queryLog = querylog.New(conf)
|
||||||
|
|
||||||
filterConf := config.DNS.DnsfilterConf
|
filterConf := config.DNS.DnsfilterConf
|
||||||
bindhost := config.DNS.BindHost
|
bindhost := config.DNS.BindHost
|
||||||
|
@ -64,22 +59,28 @@ func initDNSServer() {
|
||||||
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||||
filterConf.ConfigModified = onConfigModified
|
filterConf.ConfigModified = onConfigModified
|
||||||
filterConf.HTTPRegister = httpRegister
|
filterConf.HTTPRegister = httpRegister
|
||||||
config.dnsFilter = dnsfilter.New(&filterConf, nil)
|
Context.dnsFilter = dnsfilter.New(&filterConf, nil)
|
||||||
|
|
||||||
config.dnsServer = dnsforward.NewServer(config.dnsFilter, config.stats, config.queryLog)
|
Context.dnsServer = dnsforward.NewServer(Context.dnsFilter, Context.stats, Context.queryLog)
|
||||||
|
dnsConfig := generateServerConfig()
|
||||||
|
err = Context.dnsServer.Prepare(&dnsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dnsServer.Prepare: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
sessFilename := filepath.Join(baseDir, "sessions.db")
|
sessFilename := filepath.Join(baseDir, "sessions.db")
|
||||||
config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||||
config.Users = nil
|
config.Users = nil
|
||||||
|
|
||||||
config.dnsctx.rdns = InitRDNS(&config.clients)
|
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
||||||
config.dnsctx.whois = initWhois(&config.clients)
|
Context.whois = initWhois(&Context.clients)
|
||||||
|
|
||||||
initFiltering()
|
initFiltering()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRunning() bool {
|
func isRunning() bool {
|
||||||
return config.dnsServer != nil && config.dnsServer.IsRunning()
|
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint (gocyclo)
|
// nolint (gocyclo)
|
||||||
|
@ -145,14 +146,14 @@ func onDNSRequest(d *proxy.DNSContext) {
|
||||||
|
|
||||||
ipAddr := net.ParseIP(ip)
|
ipAddr := net.ParseIP(ip)
|
||||||
if !ipAddr.IsLoopback() {
|
if !ipAddr.IsLoopback() {
|
||||||
config.dnsctx.rdns.Begin(ip)
|
Context.rdns.Begin(ip)
|
||||||
}
|
}
|
||||||
if isPublicIP(ipAddr) {
|
if isPublicIP(ipAddr) {
|
||||||
config.dnsctx.whois.Begin(ip)
|
Context.whois.Begin(ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateServerConfig() (dnsforward.ServerConfig, error) {
|
func generateServerConfig() dnsforward.ServerConfig {
|
||||||
newconfig := dnsforward.ServerConfig{
|
newconfig := dnsforward.ServerConfig{
|
||||||
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||||
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||||
|
@ -171,11 +172,11 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
|
||||||
|
|
||||||
newconfig.FilterHandler = applyAdditionalFiltering
|
newconfig.FilterHandler = applyAdditionalFiltering
|
||||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||||
return newconfig, nil
|
return newconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUpstreamsByClient(clientAddr string) []string {
|
func getUpstreamsByClient(clientAddr string) []string {
|
||||||
c, ok := config.clients.Find(clientAddr)
|
c, ok := Context.clients.Find(clientAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
@ -192,7 +193,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := config.clients.Find(clientAddr)
|
c, ok := Context.clients.Find(clientAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -220,12 +221,7 @@ func startDNSServer() error {
|
||||||
|
|
||||||
enableFilters(false)
|
enableFilters(false)
|
||||||
|
|
||||||
newconfig, err := generateServerConfig()
|
err := Context.dnsServer.Start()
|
||||||
if err != nil {
|
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = config.dnsServer.Start(&newconfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
|
@ -233,14 +229,14 @@ func startDNSServer() error {
|
||||||
startFiltering()
|
startFiltering()
|
||||||
|
|
||||||
const topClientsNumber = 100 // the number of clients to get
|
const topClientsNumber = 100 // the number of clients to get
|
||||||
topClients := config.stats.GetTopClientsIP(topClientsNumber)
|
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
|
||||||
for _, ip := range topClients {
|
for _, ip := range topClients {
|
||||||
ipAddr := net.ParseIP(ip)
|
ipAddr := net.ParseIP(ip)
|
||||||
if !ipAddr.IsLoopback() {
|
if !ipAddr.IsLoopback() {
|
||||||
config.dnsctx.rdns.Begin(ip)
|
Context.rdns.Begin(ip)
|
||||||
}
|
}
|
||||||
if isPublicIP(ipAddr) {
|
if isPublicIP(ipAddr) {
|
||||||
config.dnsctx.whois.Begin(ip)
|
Context.whois.Begin(ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,11 +244,8 @@ func startDNSServer() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func reconfigureDNSServer() error {
|
func reconfigureDNSServer() error {
|
||||||
newconfig, err := generateServerConfig()
|
newconfig := generateServerConfig()
|
||||||
if err != nil {
|
err := Context.dnsServer.Reconfigure(&newconfig)
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
|
||||||
}
|
|
||||||
err = config.dnsServer.Reconfigure(&newconfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
|
@ -261,26 +254,22 @@ func reconfigureDNSServer() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func stopDNSServer() error {
|
func stopDNSServer() error {
|
||||||
if !isRunning() {
|
err := Context.dnsServer.Stop()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := config.dnsServer.Stop()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
||||||
config.dnsServer.Close()
|
Context.dnsServer.Close()
|
||||||
|
|
||||||
config.dnsFilter.Close()
|
Context.dnsFilter.Close()
|
||||||
config.dnsFilter = nil
|
Context.dnsFilter = nil
|
||||||
|
|
||||||
config.stats.Close()
|
Context.stats.Close()
|
||||||
config.stats = nil
|
Context.stats = nil
|
||||||
|
|
||||||
config.queryLog.Close()
|
Context.queryLog.Close()
|
||||||
config.queryLog = nil
|
Context.queryLog = nil
|
||||||
|
|
||||||
config.auth.Close()
|
config.auth.Close()
|
||||||
config.auth = nil
|
config.auth = nil
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
package home
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestResolveRDNS(t *testing.T) {
|
|
||||||
_ = os.RemoveAll(config.getDataDir())
|
|
||||||
defer func() { _ = os.RemoveAll(config.getDataDir()) }()
|
|
||||||
|
|
||||||
config.DNS.BindHost = "1.1.1.1"
|
|
||||||
initDNSServer()
|
|
||||||
if r := config.dnsctx.rdns.resolve("1.1.1.1"); r != "one.one.one.one" {
|
|
||||||
t.Errorf("resolveRDNS(): %s", r)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -514,5 +514,5 @@ func enableFilters(async bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = config.dnsFilter.SetFilters(filters, async)
|
_ = Context.dnsFilter.SetFilters(filters, async)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
@ -118,7 +117,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// enforce https?
|
// enforce https?
|
||||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
|
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
|
||||||
// yes, and we want host from host:port
|
// yes, and we want host from host:port
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -273,14 +272,8 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
|
||||||
return con, err
|
return con, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bindhost := config.DNS.BindHost
|
addrs, e := Context.dnsServer.Resolve(host)
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
|
||||||
bindhost = "127.0.0.1"
|
|
||||||
}
|
|
||||||
resolverAddr := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
|
||||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
|
||||||
addrs, e := r.LookupIPAddr(ctx, host)
|
|
||||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return nil, e
|
return nil, e
|
||||||
}
|
}
|
||||||
|
|
63
home/home.go
63
home/home.go
|
@ -21,6 +21,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/NYTimes/gziphandler"
|
"github.com/NYTimes/gziphandler"
|
||||||
"github.com/gobuffalo/packr"
|
"github.com/gobuffalo/packr"
|
||||||
|
@ -40,6 +44,23 @@ var (
|
||||||
|
|
||||||
const versionCheckPeriod = time.Hour * 8
|
const versionCheckPeriod = time.Hour * 8
|
||||||
|
|
||||||
|
// Global context
|
||||||
|
type homeContext struct {
|
||||||
|
clients clientsContainer // per-client-settings module
|
||||||
|
stats stats.Stats // statistics module
|
||||||
|
queryLog querylog.QueryLog // query log module
|
||||||
|
dnsServer *dnsforward.Server // DNS module
|
||||||
|
rdns *RDNS // rDNS module
|
||||||
|
whois *Whois // WHOIS module
|
||||||
|
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
||||||
|
dhcpServer *dhcpd.Server // DHCP module
|
||||||
|
httpServer *http.Server // HTTP module
|
||||||
|
httpsServer HTTPSServer // HTTPS module
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context - a global context object
|
||||||
|
var Context homeContext
|
||||||
|
|
||||||
// Main is the entry point
|
// Main is the entry point
|
||||||
func Main(version string, channel string) {
|
func Main(version string, channel string) {
|
||||||
// Init update-related global variables
|
// Init update-related global variables
|
||||||
|
@ -122,8 +143,8 @@ func run(args options) {
|
||||||
config.DHCP.WorkDir = config.ourWorkingDir
|
config.DHCP.WorkDir = config.ourWorkingDir
|
||||||
config.DHCP.HTTPRegister = httpRegister
|
config.DHCP.HTTPRegister = httpRegister
|
||||||
config.DHCP.ConfigModified = onConfigModified
|
config.DHCP.ConfigModified = onConfigModified
|
||||||
config.dhcpServer = dhcpd.Create(config.DHCP)
|
Context.dhcpServer = dhcpd.Create(config.DHCP)
|
||||||
config.clients.Init(config.Clients, config.dhcpServer)
|
Context.clients.Init(config.Clients, Context.dhcpServer)
|
||||||
config.Clients = nil
|
config.Clients = nil
|
||||||
|
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||||
|
@ -146,7 +167,10 @@ func run(args options) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
initDNSServer()
|
err = initDNSServer()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("%s", err)
|
||||||
|
}
|
||||||
go func() {
|
go func() {
|
||||||
err = startDNSServer()
|
err = startDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -178,21 +202,21 @@ func run(args options) {
|
||||||
registerInstallHandlers()
|
registerInstallHandlers()
|
||||||
}
|
}
|
||||||
|
|
||||||
config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
|
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
|
||||||
|
|
||||||
// for https, we have a separate goroutine loop
|
// for https, we have a separate goroutine loop
|
||||||
go httpServerLoop()
|
go httpServerLoop()
|
||||||
|
|
||||||
// this loop is used as an ability to change listening host and/or port
|
// this loop is used as an ability to change listening host and/or port
|
||||||
for !config.httpsServer.shutdown {
|
for !Context.httpsServer.shutdown {
|
||||||
printHTTPAddresses("http")
|
printHTTPAddresses("http")
|
||||||
|
|
||||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||||
config.httpServer = &http.Server{
|
Context.httpServer = &http.Server{
|
||||||
Addr: address,
|
Addr: address,
|
||||||
}
|
}
|
||||||
err := config.httpServer.ListenAndServe()
|
err := Context.httpServer.ListenAndServe()
|
||||||
if err != http.ErrServerClosed {
|
if err != http.ErrServerClosed {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -205,14 +229,14 @@ func run(args options) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpServerLoop() {
|
func httpServerLoop() {
|
||||||
for !config.httpsServer.shutdown {
|
for !Context.httpsServer.shutdown {
|
||||||
config.httpsServer.cond.L.Lock()
|
Context.httpsServer.cond.L.Lock()
|
||||||
// this mechanism doesn't let us through until all conditions are met
|
// this mechanism doesn't let us through until all conditions are met
|
||||||
for config.TLS.Enabled == false ||
|
for config.TLS.Enabled == false ||
|
||||||
config.TLS.PortHTTPS == 0 ||
|
config.TLS.PortHTTPS == 0 ||
|
||||||
len(config.TLS.PrivateKeyData) == 0 ||
|
len(config.TLS.PrivateKeyData) == 0 ||
|
||||||
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
|
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
|
||||||
config.httpsServer.cond.Wait()
|
Context.httpsServer.cond.Wait()
|
||||||
}
|
}
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
||||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||||
|
@ -236,10 +260,10 @@ func httpServerLoop() {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
config.httpsServer.cond.L.Unlock()
|
Context.httpsServer.cond.L.Unlock()
|
||||||
|
|
||||||
// prepare HTTPS server
|
// prepare HTTPS server
|
||||||
config.httpsServer.server = &http.Server{
|
Context.httpsServer.server = &http.Server{
|
||||||
Addr: address,
|
Addr: address,
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
|
@ -248,7 +272,7 @@ func httpServerLoop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
printHTTPAddresses("https")
|
printHTTPAddresses("https")
|
||||||
err = config.httpsServer.server.ListenAndServeTLS("", "")
|
err = Context.httpsServer.server.ListenAndServeTLS("", "")
|
||||||
if err != http.ErrServerClosed {
|
if err != http.ErrServerClosed {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -326,11 +350,10 @@ func configureLogger(args options) {
|
||||||
ls.LogFile = args.logFile
|
ls.LogFile = args.logFile
|
||||||
}
|
}
|
||||||
|
|
||||||
level := log.INFO
|
// log.SetLevel(log.INFO) - default
|
||||||
if ls.Verbose {
|
if ls.Verbose {
|
||||||
level = log.DEBUG
|
log.SetLevel(log.DEBUG)
|
||||||
}
|
}
|
||||||
log.SetLevel(level)
|
|
||||||
|
|
||||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||||
// When running as a Windows service, use eventlog by default if nothing else is configured
|
// When running as a Windows service, use eventlog by default if nothing else is configured
|
||||||
|
@ -378,11 +401,11 @@ func cleanup() {
|
||||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
// Stop HTTP server, possibly waiting for all active connections to be closed
|
||||||
func stopHTTPServer() {
|
func stopHTTPServer() {
|
||||||
log.Info("Stopping HTTP server...")
|
log.Info("Stopping HTTP server...")
|
||||||
config.httpsServer.shutdown = true
|
Context.httpsServer.shutdown = true
|
||||||
if config.httpsServer.server != nil {
|
if Context.httpsServer.server != nil {
|
||||||
config.httpsServer.server.Shutdown(context.TODO())
|
Context.httpsServer.server.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
config.httpServer.Shutdown(context.TODO())
|
Context.httpServer.Shutdown(context.TODO())
|
||||||
log.Info("Stopped HTTP server")
|
log.Info("Stopped HTTP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
154
home/home_test.go
Normal file
154
home/home_test.go
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const yamlConf = `bind_host: 127.0.0.1
|
||||||
|
bind_port: 3000
|
||||||
|
users: []
|
||||||
|
language: en
|
||||||
|
rlimit_nofile: 0
|
||||||
|
web_session_ttl: 720
|
||||||
|
dns:
|
||||||
|
bind_host: 127.0.0.1
|
||||||
|
port: 5354
|
||||||
|
statistics_interval: 90
|
||||||
|
querylog_enabled: true
|
||||||
|
querylog_interval: 90
|
||||||
|
querylog_memsize: 0
|
||||||
|
protection_enabled: true
|
||||||
|
blocking_mode: null_ip
|
||||||
|
blocked_response_ttl: 0
|
||||||
|
ratelimit: 100
|
||||||
|
ratelimit_whitelist: []
|
||||||
|
refuse_any: false
|
||||||
|
bootstrap_dns:
|
||||||
|
- 1.1.1.1:53
|
||||||
|
all_servers: false
|
||||||
|
allowed_clients: []
|
||||||
|
disallowed_clients: []
|
||||||
|
blocked_hosts: []
|
||||||
|
parental_block_host: family-block.dns.adguard.com
|
||||||
|
safebrowsing_block_host: standard-block.dns.adguard.com
|
||||||
|
cache_size: 0
|
||||||
|
upstream_dns:
|
||||||
|
- https://1.1.1.1/dns-query
|
||||||
|
filtering_enabled: true
|
||||||
|
filters_update_interval: 168
|
||||||
|
parental_sensitivity: 13
|
||||||
|
parental_enabled: true
|
||||||
|
safesearch_enabled: false
|
||||||
|
safebrowsing_enabled: false
|
||||||
|
safebrowsing_cache_size: 1048576
|
||||||
|
safesearch_cache_size: 1048576
|
||||||
|
parental_cache_size: 1048576
|
||||||
|
cache_time: 30
|
||||||
|
rewrites: []
|
||||||
|
blocked_services: []
|
||||||
|
tls:
|
||||||
|
enabled: false
|
||||||
|
server_name: www.example.com
|
||||||
|
force_https: false
|
||||||
|
port_https: 443
|
||||||
|
port_dns_over_tls: 853
|
||||||
|
certificate_chain: ""
|
||||||
|
private_key: ""
|
||||||
|
certificate_path: ""
|
||||||
|
private_key_path: ""
|
||||||
|
filters:
|
||||||
|
- enabled: true
|
||||||
|
url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
|
||||||
|
name: AdGuard Simplified Domain Names filter
|
||||||
|
id: 1
|
||||||
|
- enabled: false
|
||||||
|
url: https://hosts-file.net/ad_servers.txt
|
||||||
|
name: hpHosts - Ad and Tracking servers only
|
||||||
|
id: 2
|
||||||
|
- enabled: false
|
||||||
|
url: https://adaway.org/hosts.txt
|
||||||
|
name: adaway
|
||||||
|
id: 3
|
||||||
|
user_rules:
|
||||||
|
- ""
|
||||||
|
dhcp:
|
||||||
|
enabled: false
|
||||||
|
interface_name: ""
|
||||||
|
gateway_ip: ""
|
||||||
|
subnet_mask: ""
|
||||||
|
range_start: ""
|
||||||
|
range_end: ""
|
||||||
|
lease_duration: 86400
|
||||||
|
icmp_timeout_msec: 1000
|
||||||
|
clients: []
|
||||||
|
log_file: ""
|
||||||
|
verbose: false
|
||||||
|
schema_version: 5
|
||||||
|
`
|
||||||
|
|
||||||
|
// . Create a configuration file
|
||||||
|
// . Start AGH instance
|
||||||
|
// . Check Web server
|
||||||
|
// . Check DNS server
|
||||||
|
// . Wait until the filters are downloaded
|
||||||
|
// . Stop and cleanup
|
||||||
|
func TestHome(t *testing.T) {
|
||||||
|
dir := prepareTestDir()
|
||||||
|
defer func() { _ = os.RemoveAll(dir) }()
|
||||||
|
fn := filepath.Join(dir, "AdGuardHome.yaml")
|
||||||
|
|
||||||
|
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0644) == nil)
|
||||||
|
fn, _ = filepath.Abs(fn)
|
||||||
|
|
||||||
|
args := options{}
|
||||||
|
args.configFilename = fn
|
||||||
|
args.workDir = dir
|
||||||
|
go run(args)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var resp *http.Response
|
||||||
|
h := http.Client{}
|
||||||
|
for i := 0; i != 5; i++ {
|
||||||
|
resp, err = h.Get("http://127.0.0.1:3000/")
|
||||||
|
if err == nil && resp.StatusCode != 404 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
assert.Truef(t, err == nil, "%s", err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = h.Get("http://127.0.0.1:3000/control/status")
|
||||||
|
assert.Truef(t, err == nil, "%s", err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
|
||||||
|
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
|
||||||
|
assert.Truef(t, err == nil, "%s", err)
|
||||||
|
haveIP := len(addrs) != 0
|
||||||
|
assert.True(t, haveIP)
|
||||||
|
|
||||||
|
for i := 1; ; i++ {
|
||||||
|
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
|
||||||
|
if err == nil && st.Size() != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == 5 {
|
||||||
|
assert.True(t, false)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
cleanupAlways()
|
||||||
|
}
|
34
home/rdns.go
34
home/rdns.go
|
@ -2,25 +2,20 @@ package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
|
|
||||||
)
|
|
||||||
|
|
||||||
// RDNS - module context
|
// RDNS - module context
|
||||||
type RDNS struct {
|
type RDNS struct {
|
||||||
|
dnsServer *dnsforward.Server
|
||||||
clients *clientsContainer
|
clients *clientsContainer
|
||||||
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
|
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
|
||||||
upstream upstream.Upstream // Upstream object for our own DNS server
|
|
||||||
|
|
||||||
// Contains IP addresses of clients to be resolved by rDNS
|
// Contains IP addresses of clients to be resolved by rDNS
|
||||||
// If IP address is resolved, it stays here while it's inside Clients.
|
// If IP address is resolved, it stays here while it's inside Clients.
|
||||||
|
@ -30,25 +25,10 @@ type RDNS struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitRDNS - create module context
|
// InitRDNS - create module context
|
||||||
func InitRDNS(clients *clientsContainer) *RDNS {
|
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
|
||||||
r := RDNS{}
|
r := RDNS{}
|
||||||
|
r.dnsServer = dnsServer
|
||||||
r.clients = clients
|
r.clients = clients
|
||||||
var err error
|
|
||||||
|
|
||||||
bindhost := config.DNS.BindHost
|
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
|
||||||
bindhost = "127.0.0.1"
|
|
||||||
}
|
|
||||||
resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
|
||||||
|
|
||||||
opts := upstream.Options{
|
|
||||||
Timeout: rdnsTimeout,
|
|
||||||
}
|
|
||||||
r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("upstream.AddressToUpstream: %s", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cconf := cache.Config{}
|
cconf := cache.Config{}
|
||||||
cconf.EnableLRU = true
|
cconf.EnableLRU = true
|
||||||
|
@ -109,7 +89,7 @@ func (r *RDNS) resolve(ip string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := r.upstream.Exchange(&req)
|
resp, err := r.dnsServer.Exchange(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
|
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||||
return ""
|
return ""
|
||||||
|
@ -144,6 +124,6 @@ func (r *RDNS) workerLoop() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
|
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
21
home/rdns_test.go
Normal file
21
home/rdns_test.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveRDNS(t *testing.T) {
|
||||||
|
dns := &dnsforward.Server{}
|
||||||
|
conf := &dnsforward.ServerConfig{}
|
||||||
|
conf.UpstreamDNS = []string{"8.8.8.8"}
|
||||||
|
err := dns.Prepare(conf)
|
||||||
|
assert.True(t, err == nil, "%s", err)
|
||||||
|
|
||||||
|
clients := &clientsContainer{}
|
||||||
|
rdns := InitRDNS(dns, clients)
|
||||||
|
r := rdns.resolve("1.1.1.1")
|
||||||
|
assert.True(t, r == "one.one.one.one", "%s", r)
|
||||||
|
}
|
|
@ -25,6 +25,8 @@ type Config struct {
|
||||||
|
|
||||||
// Register an HTTP handler
|
// Register an HTTP handler
|
||||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||||
|
|
||||||
|
limit uint32 // maximum time we need to keep data for (in hours)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New - create object
|
// New - create object
|
||||||
|
|
|
@ -21,13 +21,8 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string,
|
||||||
|
|
||||||
// Return data
|
// Return data
|
||||||
func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||||
units := Hours
|
|
||||||
if s.limit/24 > 7 {
|
|
||||||
units = Days
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
d := s.getData(units)
|
d := s.getData()
|
||||||
log.Debug("Stats: prepared data in %v", time.Since(start))
|
log.Debug("Stats: prepared data in %v", time.Since(start))
|
||||||
|
|
||||||
if d == nil {
|
if d == nil {
|
||||||
|
@ -52,7 +47,7 @@ type config struct {
|
||||||
// Get configuration
|
// Get configuration
|
||||||
func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
resp := config{}
|
resp := config{}
|
||||||
resp.IntervalDays = s.limit / 24
|
resp.IntervalDays = s.conf.limit / 24
|
||||||
|
|
||||||
data, err := json.Marshal(resp)
|
data, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -45,7 +45,7 @@ func TestStats(t *testing.T) {
|
||||||
e.Time = 123456
|
e.Time = 123456
|
||||||
s.Update(e)
|
s.Update(e)
|
||||||
|
|
||||||
d := s.getData(Hours)
|
d := s.getData()
|
||||||
a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||||
assert.True(t, UIntArrayEquals(d["dns_queries"].([]uint64), a))
|
assert.True(t, UIntArrayEquals(d["dns_queries"].([]uint64), a))
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ func TestLargeNumbers(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d := s.getData(Hours)
|
d := s.getData()
|
||||||
assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n))
|
assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n))
|
||||||
|
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
|
@ -21,10 +21,8 @@ const (
|
||||||
|
|
||||||
// statsCtx - global context
|
// statsCtx - global context
|
||||||
type statsCtx struct {
|
type statsCtx struct {
|
||||||
limit uint32 // maximum time we need to keep data for (in hours)
|
db *bolt.DB
|
||||||
db *bolt.DB
|
conf *Config
|
||||||
|
|
||||||
conf Config
|
|
||||||
|
|
||||||
unit *unit // the current unit
|
unit *unit // the current unit
|
||||||
unitLock sync.Mutex // protect 'unit'
|
unitLock sync.Mutex // protect 'unit'
|
||||||
|
@ -67,8 +65,9 @@ func createObject(conf Config) (*statsCtx, error) {
|
||||||
if !checkInterval(conf.LimitDays) {
|
if !checkInterval(conf.LimitDays) {
|
||||||
conf.LimitDays = 1
|
conf.LimitDays = 1
|
||||||
}
|
}
|
||||||
s.limit = conf.LimitDays * 24
|
s.conf = &Config{}
|
||||||
s.conf = conf
|
*s.conf = conf
|
||||||
|
s.conf.limit = conf.LimitDays * 24
|
||||||
if conf.UnitID == nil {
|
if conf.UnitID == nil {
|
||||||
s.conf.UnitID = newUnitID
|
s.conf.UnitID = newUnitID
|
||||||
}
|
}
|
||||||
|
@ -82,7 +81,7 @@ func createObject(conf Config) (*statsCtx, error) {
|
||||||
var udb *unitDB
|
var udb *unitDB
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
log.Tracef("Deleting old units...")
|
log.Tracef("Deleting old units...")
|
||||||
firstID := id - s.limit - 1
|
firstID := id - s.conf.limit - 1
|
||||||
unitDel := 0
|
unitDel := 0
|
||||||
forEachBkt := func(name []byte, b *bolt.Bucket) error {
|
forEachBkt := func(name []byte, b *bolt.Bucket) error {
|
||||||
id := uint32(btoi(name))
|
id := uint32(btoi(name))
|
||||||
|
@ -243,7 +242,7 @@ func (s *statsCtx) periodicFlush() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ok1 := s.flushUnitToDB(tx, u.id, udb)
|
ok1 := s.flushUnitToDB(tx, u.id, udb)
|
||||||
ok2 := s.deleteUnit(tx, id-s.limit)
|
ok2 := s.deleteUnit(tx, id-s.conf.limit)
|
||||||
if ok1 || ok2 {
|
if ok1 || ok2 {
|
||||||
s.commitTxn(tx)
|
s.commitTxn(tx)
|
||||||
} else {
|
} else {
|
||||||
|
@ -383,12 +382,14 @@ func convertTopArray(a []countPair) []map[string]uint64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statsCtx) setLimit(limitDays int) {
|
func (s *statsCtx) setLimit(limitDays int) {
|
||||||
s.limit = uint32(limitDays) * 24
|
conf := *s.conf
|
||||||
|
conf.limit = uint32(limitDays) * 24
|
||||||
|
s.conf = &conf
|
||||||
log.Debug("Stats: set limit: %d", limitDays)
|
log.Debug("Stats: set limit: %d", limitDays)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) {
|
func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) {
|
||||||
dc.Interval = s.limit / 24
|
dc.Interval = s.conf.limit / 24
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statsCtx) Close() {
|
func (s *statsCtx) Close() {
|
||||||
|
@ -466,7 +467,7 @@ func (s *statsCtx) Update(e Entry) {
|
||||||
s.unitLock.Unlock()
|
s.unitLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statsCtx) loadUnits() ([]*unitDB, uint32) {
|
func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
|
||||||
tx := s.beginTxn(false)
|
tx := s.beginTxn(false)
|
||||||
if tx == nil {
|
if tx == nil {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
|
@ -478,7 +479,7 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) {
|
||||||
s.unitLock.Unlock()
|
s.unitLock.Unlock()
|
||||||
|
|
||||||
units := []*unitDB{} //per-hour units
|
units := []*unitDB{} //per-hour units
|
||||||
firstID := curID - s.limit + 1
|
firstID := curID - limit + 1
|
||||||
for i := firstID; i != curID; i++ {
|
for i := firstID; i != curID; i++ {
|
||||||
u := s.loadUnitFromDB(tx, i)
|
u := s.loadUnitFromDB(tx, i)
|
||||||
if u == nil {
|
if u == nil {
|
||||||
|
@ -492,8 +493,8 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) {
|
||||||
|
|
||||||
units = append(units, curUnit)
|
units = append(units, curUnit)
|
||||||
|
|
||||||
if len(units) != int(s.limit) {
|
if len(units) != int(limit) {
|
||||||
log.Fatalf("len(units) != s.limit: %d %d", len(units), s.limit)
|
log.Fatalf("len(units) != limit: %d %d", len(units), limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
return units, firstID
|
return units, firstID
|
||||||
|
@ -527,10 +528,16 @@ func (s *statsCtx) loadUnits() ([]*unitDB, uint32) {
|
||||||
These values are just the sum of data for all units.
|
These values are just the sum of data for all units.
|
||||||
*/
|
*/
|
||||||
// nolint (gocyclo)
|
// nolint (gocyclo)
|
||||||
func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} {
|
func (s *statsCtx) getData() map[string]interface{} {
|
||||||
d := map[string]interface{}{}
|
limit := s.conf.limit
|
||||||
|
|
||||||
units, firstID := s.loadUnits()
|
d := map[string]interface{}{}
|
||||||
|
timeUnit := Hours
|
||||||
|
if limit/24 > 7 {
|
||||||
|
timeUnit = Days
|
||||||
|
}
|
||||||
|
|
||||||
|
units, firstID := s.loadUnits(limit)
|
||||||
if units == nil {
|
if units == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -561,8 +568,8 @@ func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} {
|
||||||
if id <= nextDayID {
|
if id <= nextDayID {
|
||||||
a = append(a, sum)
|
a = append(a, sum)
|
||||||
}
|
}
|
||||||
if len(a) != int(s.limit/24) {
|
if len(a) != int(limit/24) {
|
||||||
log.Fatalf("len(a) != s.limit: %d %d", len(a), s.limit)
|
log.Fatalf("len(a) != limit: %d %d", len(a), limit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
d["dns_queries"] = a
|
d["dns_queries"] = a
|
||||||
|
@ -705,8 +712,8 @@ func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statsCtx) GetTopClientsIP(limit uint) []string {
|
func (s *statsCtx) GetTopClientsIP(maxCount uint) []string {
|
||||||
units, _ := s.loadUnits()
|
units, _ := s.loadUnits(s.conf.limit)
|
||||||
if units == nil {
|
if units == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -718,7 +725,7 @@ func (s *statsCtx) GetTopClientsIP(limit uint) []string {
|
||||||
m[it.Name] += it.Count
|
m[it.Name] += it.Count
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
a := convertMapToArray(m, int(limit))
|
a := convertMapToArray(m, int(maxCount))
|
||||||
d := []string{}
|
d := []string{}
|
||||||
for _, it := range a {
|
for _, it := range a {
|
||||||
d = append(d, it.Name)
|
d = append(d, it.Name)
|
||||||
|
|
Loading…
Reference in a new issue