cherry-pick: 4358 fix stats

Merge in DNS/adguard-home from 4358-fix-stats to master

Updates #4358.
Updates #4342.

Squashed commit of the following:

commit 5683cb304688ea639e5ba7f219a7bf12370211a4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 18:20:54 2022 +0300

    stats: rm races test

commit 63dd67650ed64eaf9685b955a4fdf3c0067a7f8c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 17:13:36 2022 +0300

    stats: try to imp test

commit 59a0f249fc00566872db62e362c87bc0c201b333
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 16:38:57 2022 +0300

    stats: fix nil ptr deref

commit 7fc3ff18a34a1d0e0fec3ca83a33f499ac752572
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 7 16:02:51 2022 +0300

    stats: fix races finally, imp tests

commit c63f5f4e7929819fe79b3a1e392f6b91cd630846
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 00:56:49 2022 +0300

    aghhttp: add register func

commit 61adc7f0e95279c1b7f4a0c0af5ab387ee461411
Merge: edbdb2d4 9b3adac1
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 00:36:01 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit edbdb2d4c6a06dcbf8107a28c4c3a61ba394e907
Merge: a91e4d7a a481ff4c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 21:00:42 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit a91e4d7af13591eeef45cb7980d1ebc1650a5cb7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:46:19 2022 +0300

    stats: imp code, docs

commit c5f3814c5c1a734ca8ff6726cc9ffc1177a055cf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:16:13 2022 +0300

    all: log changes

commit 5e6caafc771dddc4c6be07c34658de359106fbe5
Merge: 091ba756 eb8e8166
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:09:10 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit 091ba75618d3689b9c04f05431283417c8cc52f9
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:07:39 2022 +0300

    stats: imp docs, code

commit f2b2de77ce5f0448d6df9232a614a3710f1e2e8a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Aug 2 17:09:30 2022 +0300

    all: refactor stats & add mutexes

commit b3f11c455ceaa3738ec20eefc46f866ff36ed046
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 27 15:30:09 2022 +0300

    WIP
This commit is contained in:
Eugene Burkov 2022-08-04 19:05:28 +03:00 committed by Ainar Garipov
parent 56dc3eab02
commit 39b404be19
15 changed files with 434 additions and 300 deletions

View file

@ -9,6 +9,12 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// RegisterFunc is the function that sets the handler to handle the URL for the
// method.
//
// TODO(e.burkov, a.garipov): Get rid of it.
type RegisterFunc func(method, url string, handler http.HandlerFunc)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {

View file

@ -5,11 +5,11 @@ import (
"encoding/json"
"fmt"
"net"
"net/http"
"path/filepath"
"runtime"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
@ -126,7 +126,7 @@ type ServerConfig struct {
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`

View file

@ -5,12 +5,12 @@ import (
"crypto/x509"
"fmt"
"net"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
@ -191,7 +191,7 @@ type ServerConfig struct {
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
HTTPRegister aghhttp.RegisterFunc
// ResolveClients signals if the RDNS should resolve clients' addresses.
ResolveClients bool

View file

@ -61,7 +61,7 @@ type Server struct {
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
stats stats.Stats
stats stats.Interface
access *accessCtx
// localDomainSuffix is the suffix used to detect internal hosts. It
@ -107,7 +107,7 @@ const defaultLocalDomainSuffix = "lan"
// DNSCreateParams are parameters to create a new server.
type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Stats
Stats stats.Interface
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
PrivateNets netutil.SubnetSet

View file

@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// actually implementing all methods.
stats.Stats
stats.Interface
lastEntry stats.Entry
}

View file

@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net"
"net/http"
"os"
"runtime"
"runtime/debug"
@ -14,6 +13,7 @@ import (
"sync"
"sync/atomic"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
@ -94,7 +94,7 @@ type Config struct {
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"`

View file

@ -2,6 +2,7 @@ package home
import (
"bytes"
"encoding"
"fmt"
"net"
"sort"
@ -60,6 +61,33 @@ const (
ClientSourceHostsFile
)
var _ fmt.Stringer = clientSource(0)
// String returns a human-readable name of cs.
func (cs clientSource) String() (s string) {
switch cs {
case ClientSourceWHOIS:
return "WHOIS"
case ClientSourceARP:
return "ARP"
case ClientSourceRDNS:
return "rDNS"
case ClientSourceDHCP:
return "DHCP"
case ClientSourceHostsFile:
return "etc/hosts"
default:
return ""
}
}
var _ encoding.TextMarshaler = clientSource(0)
// MarshalText implements encoding.TextMarshaler for the clientSource.
func (cs clientSource) MarshalText() (text []byte, err error) {
return []byte(cs.String()), nil
}
// clientSourceConf is used to configure where the runtime clients will be
// obtained from.
type clientSourcesConf struct {
@ -397,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
c.Tags = stringutil.CloneSlice(c.Tags)
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
return c, true
}

View file

@ -47,9 +47,9 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
Source string `json:"source"`
IP net.IP `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IP: ip,
}
cj.Source = "etc/hosts"
switch rc.Source {
case ClientSourceDHCP:
cj.Source = "DHCP"
case ClientSourceRDNS:
cj.Source = "rDNS"
case ClientSourceARP:
cj.Source = "ARP"
case ClientSourceWHOIS:
cj.Source = "WHOIS"
Name: rc.Host,
Source: rc.Source,
IP: ip,
}
data.RuntimeClients = append(data.RuntimeClients, cj)
@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
return
}
@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
rc, ok := clients.FindRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View file

@ -189,7 +189,7 @@ func registerControlHandlers() {
RegisterAuthHandlers()
}
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))

View file

@ -46,7 +46,7 @@ type homeContext struct {
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
stats stats.Interface // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module

View file

@ -2,10 +2,10 @@ package querylog
import (
"net"
"net/http"
"path/filepath"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
@ -38,7 +38,7 @@ type Config struct {
ConfigModified func()
// HTTPRegister registers an HTTP handler.
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
HTTPRegister aghhttp.RegisterFunc
// FindClient returns client information by their IDs.
FindClient func(ids []string) (c *Client, err error)

View file

@ -39,34 +39,21 @@ type statsResponse struct {
}
// handleStats is a handler for getting statistics.
func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
start := time.Now()
var resp statsResponse
if s.conf.limit == 0 {
resp = statsResponse{
TimeUnits: "days",
var ok bool
resp, ok = s.getData()
TopBlocked: []topAddrs{},
TopClients: []topAddrs{},
TopQueried: []topAddrs{},
log.Debug("stats: prepared data in %v", time.Since(start))
BlockedFiltering: []uint64{},
DNSQueries: []uint64{},
ReplacedParental: []uint64{},
ReplacedSafebrowsing: []uint64{},
}
} else {
var ok bool
resp, ok = s.getData()
if !ok {
// Don't bring the message to the lower case since it's a part of UI
// text for the moment.
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
log.Debug("stats: prepared data in %v", time.Since(start))
if !ok {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
return
}
return
}
w.Header().Set("Content-Type", "application/json")
@ -84,9 +71,9 @@ type config struct {
}
// Get configuration
func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
resp := config{}
resp.IntervalDays = s.conf.limit / 24
resp.IntervalDays = s.limitHours / 24
data, err := json.Marshal(resp)
if err != nil {
@ -102,7 +89,7 @@ func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
}
// Set configuration
func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
reqData := config{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
@ -118,22 +105,22 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
}
s.setLimit(int(reqData.IntervalDays))
s.conf.ConfigModified()
s.configModified()
}
// Reset data
func (s *statsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
s.clear()
}
// Register web handlers
func (s *statsCtx) initWeb() {
if s.conf.HTTPRegister == nil {
func (s *StatsCtx) initWeb() {
if s.httpRegister == nil {
return
}
s.conf.HTTPRegister(http.MethodGet, "/control/stats", s.handleStats)
s.conf.HTTPRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.conf.HTTPRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
s.conf.HTTPRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
}

View file

@ -4,75 +4,85 @@ package stats
import (
"net"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
)
type unitIDCallback func() uint32
// UnitIDGenFunc is the signature of a function that generates a unique ID for
// the statistics unit.
type UnitIDGenFunc func() (id uint32)
// DiskConfig - configuration settings that are stored on disk
// DiskConfig is the configuration structure that is stored in file.
type DiskConfig struct {
Interval uint32 `yaml:"statistics_interval"` // time interval for statistics (in days)
// Interval is the number of days for which the statistics are collected
// before flushing to the database.
Interval uint32 `yaml:"statistics_interval"`
}
// Config - module configuration
// Config is the configuration structure for the statistics collecting.
type Config struct {
Filename string // database file name
LimitDays uint32 // time limit (in days)
UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used.
// UnitID is the function to generate the identifier for current unit. If
// nil, the default function is used, see newUnitID.
UnitID UnitIDGenFunc
// Called when the configuration is changed by HTTP request
// ConfigModified will be called each time the configuration changed via web
// interface.
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
// HTTPRegister is the function that registers handlers for the stats
// endpoints.
HTTPRegister aghhttp.RegisterFunc
limit uint32 // maximum time we need to keep data for (in hours)
// Filename is the name of the database file.
Filename string
// LimitDays is the maximum number of days to collect statistics into the
// current unit.
LimitDays uint32
}
// New - create object
func New(conf Config) (Stats, error) {
return createObject(conf)
}
// Stats - main interface
type Stats interface {
// Interface is the statistics interface to be used by other packages.
type Interface interface {
// Start begins the statistics collecting.
Start()
// Close object.
// This function is not thread safe
// (can't be called in parallel with any other function of this interface).
// Close stops the statistics collecting.
Close()
// Update counters
// Update collects the incoming statistics data.
Update(e Entry)
// Get IP addresses of the clients with the most number of requests
// GetTopClientIP returns at most limit IP addresses corresponding to the
// clients with the most number of requests.
GetTopClientsIP(limit uint) []net.IP
// WriteDiskConfig - write configuration
// WriteDiskConfig puts the Interface's configuration to the dc.
WriteDiskConfig(dc *DiskConfig)
}
// TimeUnit - time unit
// TimeUnit is the unit of measuring time while aggregating the statistics.
type TimeUnit int
// Supported time units
// Supported TimeUnit values.
const (
Hours TimeUnit = iota
Days
)
// Result of DNS request processing
// Result is the resulting code of processing the DNS request.
type Result int
// Supported result values
// Supported Result values.
//
// TODO(e.burkov): Think about better naming.
const (
RNotFiltered Result = iota + 1
RFiltered
RSafeBrowsing
RSafeSearch
RParental
rLast
resultLast = RParental + 1
)
// Entry is a statistics data entry.
@ -82,7 +92,12 @@ type Entry struct {
// TODO(a.garipov): Make this a {net.IP, string} enum?
Client string
// Domain is the domain name requested.
Domain string
// Result is the result of processing the request.
Result Result
Time uint32 // processing time (msec)
// Time is the duration of the request processing in milliseconds.
Time uint32
}

View file

@ -37,7 +37,7 @@ func TestStats(t *testing.T) {
LimitDays: 1,
}
s, err := createObject(conf)
s, err := New(conf)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.clear()
@ -110,7 +110,7 @@ func TestLargeNumbers(t *testing.T) {
LimitDays: 1,
UnitID: newID,
}
s, err := createObject(conf)
s, err := New(conf)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.Close()

View file

@ -9,11 +9,13 @@ import (
"os"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
bolt "go.etcd.io/bbolt"
"go.etcd.io/bbolt"
)
// TODO(a.garipov): Rewrite all of this. Add proper error handling and
@ -24,47 +26,130 @@ const (
maxClients = 100 // max number of top clients to store in file or return via Get()
)
// statsCtx - global context
type statsCtx struct {
// mu protects unit.
mu *sync.Mutex
// current is the actual statistics collection result.
current *unit
// StatsCtx collects the statistics and flushes it to the database. Its default
// flushing interval is one hour.
//
// TODO(e.burkov): Use atomic.Pointer for accessing curr and db in go1.19.
type StatsCtx struct {
// currMu protects the current unit.
currMu *sync.Mutex
// curr is the actual statistics collection result.
curr *unit
db *bolt.DB
conf *Config
// dbMu protects db.
dbMu *sync.Mutex
// db is the opened statistics database, if any.
db *bbolt.DB
// unitIDGen is the function that generates an identifier for the current
// unit. It's here for only testing purposes.
unitIDGen UnitIDGenFunc
// httpRegister is used to set HTTP handlers.
httpRegister aghhttp.RegisterFunc
// configModified is called whenever the configuration is modified via web
// interface.
configModified func()
// filename is the name of database file.
filename string
// limitHours is the maximum number of hours to collect statistics into the
// current unit.
limitHours uint32
}
// data for 1 time unit
// unit collects the statistics data for a specific period of time.
type unit struct {
id uint32 // unit ID. Default: absolute hour since Jan 1, 1970
// mu protects all the fields of a unit.
mu *sync.RWMutex
nTotal uint64 // total requests
nResult []uint64 // number of requests per one result
timeSum uint64 // sum of processing time of all requests (usec)
// id is the unique unit's identifier. It's set to an absolute hour number
// since the beginning of UNIX time by the default ID generating function.
id uint32
// top:
domains map[string]uint64 // number of requests per domain
blockedDomains map[string]uint64 // number of blocked requests per domain
clients map[string]uint64 // number of requests per client
// nTotal stores the total number of requests.
nTotal uint64
// nResult stores the number of requests grouped by it's result.
nResult []uint64
// timeSum stores the sum of processing time in milliseconds of each request
// written by the unit.
timeSum uint64
// domains stores the number of requests for each domain.
domains map[string]uint64
// blockedDomains stores the number of requests for each domain that has
// been blocked.
blockedDomains map[string]uint64
// clients stores the number of requests from each client.
clients map[string]uint64
}
// name-count pair
// ongoing returns the current unit. It's safe for concurrent use.
//
// Note that the unit itself should be locked before accessing.
func (s *StatsCtx) ongoing() (u *unit) {
s.currMu.Lock()
defer s.currMu.Unlock()
return s.curr
}
// swapCurrent swaps the current unit with another and returns it. It's safe
// for concurrent use.
func (s *StatsCtx) swapCurrent(with *unit) (old *unit) {
s.currMu.Lock()
defer s.currMu.Unlock()
old, s.curr = s.curr, with
return old
}
// database returns the database if it's opened. It's safe for concurrent use.
func (s *StatsCtx) database() (db *bbolt.DB) {
s.dbMu.Lock()
defer s.dbMu.Unlock()
return s.db
}
// swapDatabase swaps the database with another one and returns it. It's safe
// for concurrent use.
func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) {
s.dbMu.Lock()
defer s.dbMu.Unlock()
old, s.db = s.db, with
return old
}
// countPair is a single name-number pair for deserializing statistics data into
// the database.
type countPair struct {
Name string
Count uint64
}
// structure for storing data in file
// unitDB is the structure for deserializing statistics data into the database.
type unitDB struct {
NTotal uint64
// NTotal is the total number of requests.
NTotal uint64
// NResult is the number of requests by the result's kind.
NResult []uint64
Domains []countPair
// Domains is the number of requests for each domain name.
Domains []countPair
// BlockedDomains is the number of requests blocked for each domain name.
BlockedDomains []countPair
Clients []countPair
// Clients is the number of requests from each client.
Clients []countPair
TimeAvg uint32 // usec
// TimeAvg is the average of processing times in milliseconds of all the
// requests in the unit.
TimeAvg uint32
}
// withRecovered turns the value recovered from panic if any into an error and
@ -86,34 +171,40 @@ func withRecovered(orig *error) {
*orig = errors.WithDeferred(*orig, err)
}
// createObject creates s from conf and properly initializes it.
func createObject(conf Config) (s *statsCtx, err error) {
// isEnabled is a helper that check if the statistics collecting is enabled.
func (s *StatsCtx) isEnabled() (ok bool) {
return atomic.LoadUint32(&s.limitHours) != 0
}
// New creates s from conf and properly initializes it. Don't use s before
// calling it's Start method.
func New(conf Config) (s *StatsCtx, err error) {
defer withRecovered(&err)
s = &statsCtx{
mu: &sync.Mutex{},
s = &StatsCtx{
currMu: &sync.Mutex{},
dbMu: &sync.Mutex{},
filename: conf.Filename,
configModified: conf.ConfigModified,
httpRegister: conf.HTTPRegister,
}
if !checkInterval(conf.LimitDays) {
conf.LimitDays = 1
if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
s.limitHours = 24
}
if s.unitIDGen = newUnitID; conf.UnitID != nil {
s.unitIDGen = conf.UnitID
}
s.conf = &Config{}
*s.conf = conf
s.conf.limit = conf.LimitDays * 24
if conf.UnitID == nil {
s.conf.UnitID = newUnitID
if err = s.dbOpen(); err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}
if !s.dbOpen() {
return nil, fmt.Errorf("open database")
}
id := s.conf.UnitID()
tx := s.beginTxn(true)
id := s.unitIDGen()
tx := beginTxn(s.db, true)
var udb *unitDB
if tx != nil {
log.Tracef("Deleting old units...")
firstID := id - s.conf.limit - 1
firstID := id - s.limitHours - 1
unitDel := 0
err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID))
@ -133,12 +224,11 @@ func createObject(conf Config) (s *statsCtx, err error) {
}
}
u := unit{}
s.initUnit(&u, id)
if udb != nil {
deserialize(&u, udb)
}
s.current = &u
u := newUnit(id)
// This use of deserialize is safe since the accessed unit has just been
// created.
u.deserialize(udb)
s.curr = u
log.Debug("stats: initialized")
@ -153,11 +243,11 @@ const errStop errors.Error = "stop iteration"
// integer that unitDelPtr points to is incremented for every successful
// deletion. If the bucket isn't deleted, f returns errStop.
func newBucketWalker(
tx *bolt.Tx,
tx *bbolt.Tx,
unitDelPtr *int,
firstID uint32,
) (f func(name []byte, b *bolt.Bucket) (err error)) {
return func(name []byte, _ *bolt.Bucket) (err error) {
) (f func(name []byte, b *bbolt.Bucket) (err error)) {
return func(name []byte, _ *bbolt.Bucket) (err error) {
nameID, ok := unitNameToID(name)
if !ok || nameID < firstID {
err = tx.DeleteBucket(name)
@ -178,80 +268,92 @@ func newBucketWalker(
}
}
func (s *statsCtx) Start() {
// Start makes s process the incoming data.
func (s *StatsCtx) Start() {
s.initWeb()
go s.periodicFlush()
}
func checkInterval(days uint32) bool {
// checkInterval returns true if days is valid to be used as statistics
// retention interval. The valid values are 0, 1, 7, 30 and 90.
func checkInterval(days uint32) (ok bool) {
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
}
func (s *statsCtx) dbOpen() bool {
var err error
// dbOpen returns an error if the database can't be opened from the specified
// file. It's safe for concurrent use.
func (s *StatsCtx) dbOpen() (err error) {
log.Tracef("db.Open...")
s.db, err = bolt.Open(s.conf.Filename, 0o644, nil)
s.dbMu.Lock()
defer s.dbMu.Unlock()
s.db, err = bbolt.Open(s.filename, 0o644, nil)
if err != nil {
log.Error("stats: open DB: %s: %s", s.conf.Filename, err)
log.Error("stats: open DB: %s: %s", s.filename, err)
if err.Error() == "invalid argument" {
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
}
return false
return err
}
log.Tracef("db.Open")
return true
return nil
}
// Atomically swap the currently active unit with a new value
// Return old value
func (s *statsCtx) swapUnit(new *unit) (u *unit) {
s.mu.Lock()
defer s.mu.Unlock()
// newUnitID is the default UnitIDGenFunc that generates the unique id hourly.
func newUnitID() (id uint32) {
const secsInHour = int64(time.Hour / time.Second)
u = s.current
s.current = new
return u
return uint32(time.Now().Unix() / secsInHour)
}
// Get unit ID for the current hour
func newUnitID() uint32 {
return uint32(time.Now().Unix() / (60 * 60))
// newUnit allocates the new *unit.
func newUnit(id uint32) (u *unit) {
return &unit{
mu: &sync.RWMutex{},
id: id,
nResult: make([]uint64, resultLast),
domains: make(map[string]uint64),
blockedDomains: make(map[string]uint64),
clients: make(map[string]uint64),
}
}
// Initialize a unit
func (s *statsCtx) initUnit(u *unit, id uint32) {
u.id = id
u.nResult = make([]uint64, rLast)
u.domains = make(map[string]uint64)
u.blockedDomains = make(map[string]uint64)
u.clients = make(map[string]uint64)
}
// Open a DB transaction
func (s *statsCtx) beginTxn(wr bool) *bolt.Tx {
db := s.db
// beginTxn opens a new database transaction. If writable is true, the
// transaction will be opened for writing, and for reading otherwise. It
// returns nil if the transaction can't be created.
func beginTxn(db *bbolt.DB, writable bool) (tx *bbolt.Tx) {
if db == nil {
return nil
}
log.Tracef("db.Begin...")
tx, err := db.Begin(wr)
log.Tracef("opening a database transaction")
tx, err := db.Begin(writable)
if err != nil {
log.Error("db.Begin: %s", err)
log.Error("stats: opening a transaction: %s", err)
return nil
}
log.Tracef("db.Begin")
log.Tracef("transaction has been opened")
return tx
}
func (s *statsCtx) commitTxn(tx *bolt.Tx) {
// commitTxn applies the changes made in tx to the database.
func (s *StatsCtx) commitTxn(tx *bbolt.Tx) {
err := tx.Commit()
if err != nil {
log.Debug("tx.Commit: %s", err)
log.Error("stats: committing a transaction: %s", err)
return
}
log.Tracef("tx.Commit")
log.Tracef("transaction has been committed")
}
// bucketNameLen is the length of a bucket, a 64-bit unsigned integer.
@ -262,10 +364,10 @@ const bucketNameLen = 8
// idToUnitName converts a numerical ID into a database unit name.
func idToUnitName(id uint32) (name []byte) {
name = make([]byte, bucketNameLen)
binary.BigEndian.PutUint64(name, uint64(id))
n := [bucketNameLen]byte{}
binary.BigEndian.PutUint64(n[:], uint64(id))
return name
return n[:]
}
// unitNameToID converts a database unit name into a numerical ID. ok is false
@ -278,13 +380,6 @@ func unitNameToID(name []byte) (id uint32, ok bool) {
return uint32(binary.BigEndian.Uint64(name)), true
}
func (s *statsCtx) ongoing() (u *unit) {
s.mu.Lock()
defer s.mu.Unlock()
return s.current
}
// Flush the current unit to DB and delete an old unit when a new hour is started
// If a unit must be flushed:
// . lock DB
@ -293,34 +388,29 @@ func (s *statsCtx) ongoing() (u *unit) {
// . write the unit to DB
// . remove the stale unit from DB
// . unlock DB
func (s *statsCtx) periodicFlush() {
for {
ptr := s.ongoing()
if ptr == nil {
break
}
id := s.conf.UnitID()
if ptr.id == id || s.conf.limit == 0 {
func (s *StatsCtx) periodicFlush() {
for ptr := s.ongoing(); ptr != nil; ptr = s.ongoing() {
id := s.unitIDGen()
// Access the unit's ID with atomic to avoid locking the whole unit.
if !s.isEnabled() || atomic.LoadUint32(&ptr.id) == id {
time.Sleep(time.Second)
continue
}
tx := s.beginTxn(true)
tx := beginTxn(s.database(), true)
nu := unit{}
s.initUnit(&nu, id)
u := s.swapUnit(&nu)
udb := serialize(u)
nu := newUnit(id)
u := s.swapCurrent(nu)
udb := u.serialize()
if tx == nil {
continue
}
ok1 := s.flushUnitToDB(tx, u.id, udb)
ok2 := s.deleteUnit(tx, id-s.conf.limit)
if ok1 || ok2 {
flushOK := flushUnitToDB(tx, u.id, udb)
delOK := s.deleteUnit(tx, id-atomic.LoadUint32(&s.limitHours))
if flushOK || delOK {
s.commitTxn(tx)
} else {
_ = tx.Rollback()
@ -330,8 +420,8 @@ func (s *statsCtx) periodicFlush() {
log.Tracef("periodicFlush() exited")
}
// Delete unit's data from file
func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
// deleteUnit removes the unit by it's id from the database the tx belongs to.
func (s *StatsCtx) deleteUnit(tx *bbolt.Tx, id uint32) bool {
err := tx.DeleteBucket(idToUnitName(id))
if err != nil {
log.Tracef("stats: bolt DeleteBucket: %s", err)
@ -347,10 +437,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
func convertMapToSlice(m map[string]uint64, max int) []countPair {
a := []countPair{}
for k, v := range m {
pair := countPair{}
pair.Name = k
pair.Count = v
a = append(a, pair)
a = append(a, countPair{Name: k, Count: v})
}
less := func(i, j int) bool {
return a[j].Count < a[i].Count
@ -370,41 +457,46 @@ func convertSliceToMap(a []countPair) map[string]uint64 {
return m
}
func serialize(u *unit) *unitDB {
udb := unitDB{}
udb.NTotal = u.nTotal
udb.NResult = append(udb.NResult, u.nResult...)
// serialize converts u to the *unitDB. It's safe for concurrent use.
func (u *unit) serialize() (udb *unitDB) {
u.mu.RLock()
defer u.mu.RUnlock()
var timeAvg uint32 = 0
if u.nTotal != 0 {
udb.TimeAvg = uint32(u.timeSum / u.nTotal)
timeAvg = uint32(u.timeSum / u.nTotal)
}
udb.Domains = convertMapToSlice(u.domains, maxDomains)
udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains)
udb.Clients = convertMapToSlice(u.clients, maxClients)
return &udb
return &unitDB{
NTotal: u.nTotal,
NResult: append([]uint64{}, u.nResult...),
Domains: convertMapToSlice(u.domains, maxDomains),
BlockedDomains: convertMapToSlice(u.blockedDomains, maxDomains),
Clients: convertMapToSlice(u.clients, maxClients),
TimeAvg: timeAvg,
}
}
func deserialize(u *unit, udb *unitDB) {
// deserealize assigns the appropriate values from udb to u. u must not be nil.
// It's safe for concurrent use.
func (u *unit) deserialize(udb *unitDB) {
if udb == nil {
return
}
u.mu.Lock()
defer u.mu.Unlock()
u.nTotal = udb.NTotal
n := len(udb.NResult)
if n < len(u.nResult) {
n = len(u.nResult) // n = min(len(udb.NResult), len(u.nResult))
}
for i := 1; i < n; i++ {
u.nResult[i] = udb.NResult[i]
}
u.nResult = make([]uint64, resultLast)
copy(u.nResult, udb.NResult)
u.domains = convertSliceToMap(udb.Domains)
u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
u.clients = convertSliceToMap(udb.Clients)
u.timeSum = uint64(udb.TimeAvg) * u.nTotal
u.timeSum = uint64(udb.TimeAvg) * udb.NTotal
}
func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool {
func flushUnitToDB(tx *bbolt.Tx, id uint32, udb *unitDB) bool {
log.Tracef("Flushing unit %d", id)
bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id))
@ -430,7 +522,7 @@ func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool {
return true
}
func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
func (s *StatsCtx) loadUnitFromDB(tx *bbolt.Tx, id uint32) *unitDB {
bkt := tx.Bucket(idToUnitName(id))
if bkt == nil {
return nil
@ -451,44 +543,44 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
return &udb
}
func convertTopSlice(a []countPair) []map[string]uint64 {
m := []map[string]uint64{}
func convertTopSlice(a []countPair) (m []map[string]uint64) {
m = make([]map[string]uint64, 0, len(a))
for _, it := range a {
ent := map[string]uint64{}
ent[it.Name] = it.Count
m = append(m, ent)
m = append(m, map[string]uint64{it.Name: it.Count})
}
return m
}
func (s *statsCtx) setLimit(limitDays int) {
s.conf.limit = uint32(limitDays) * 24
func (s *StatsCtx) setLimit(limitDays int) {
atomic.StoreUint32(&s.limitHours, uint32(24*limitDays))
if limitDays == 0 {
s.clear()
}
log.Debug("stats: set limit: %d", limitDays)
log.Debug("stats: set limit: %d days", limitDays)
}
func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) {
dc.Interval = s.conf.limit / 24
func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
}
func (s *statsCtx) Close() {
u := s.swapUnit(nil)
udb := serialize(u)
tx := s.beginTxn(true)
if tx != nil {
if s.flushUnitToDB(tx, u.id, udb) {
func (s *StatsCtx) Close() {
u := s.swapCurrent(nil)
db := s.database()
if tx := beginTxn(db, true); tx != nil {
udb := u.serialize()
if flushUnitToDB(tx, u.id, udb) {
s.commitTxn(tx)
} else {
_ = tx.Rollback()
}
}
if s.db != nil {
if db != nil {
log.Tracef("db.Close...")
_ = s.db.Close()
_ = db.Close()
log.Tracef("db.Close")
}
@ -496,11 +588,11 @@ func (s *statsCtx) Close() {
}
// Reset counters and clear database
func (s *statsCtx) clear() {
tx := s.beginTxn(true)
func (s *StatsCtx) clear() {
db := s.database()
tx := beginTxn(db, true)
if tx != nil {
db := s.db
s.db = nil
_ = s.swapDatabase(nil)
_ = tx.Rollback()
// the active transactions can continue using database,
// but no new transactions will be opened
@ -509,11 +601,10 @@ func (s *statsCtx) clear() {
// all active transactions are now closed
}
u := unit{}
s.initUnit(&u, s.conf.UnitID())
_ = s.swapUnit(&u)
u := newUnit(s.unitIDGen())
_ = s.swapCurrent(u)
err := os.Remove(s.conf.Filename)
err := os.Remove(s.filename)
if err != nil {
log.Error("os.Remove: %s", err)
}
@ -523,13 +614,13 @@ func (s *statsCtx) clear() {
log.Debug("stats: cleared")
}
func (s *statsCtx) Update(e Entry) {
if s.conf.limit == 0 {
func (s *StatsCtx) Update(e Entry) {
if !s.isEnabled() {
return
}
if e.Result == 0 ||
e.Result >= rLast ||
e.Result >= resultLast ||
e.Domain == "" ||
e.Client == "" {
return
@ -540,13 +631,15 @@ func (s *statsCtx) Update(e Entry) {
clientID = ip.String()
}
s.mu.Lock()
defer s.mu.Unlock()
u := s.ongoing()
if u == nil {
return
}
u := s.current
u.mu.Lock()
defer u.mu.Unlock()
u.nResult[e.Result]++
if e.Result == RNotFiltered {
u.domains[e.Domain]++
} else {
@ -558,14 +651,19 @@ func (s *statsCtx) Update(e Entry) {
u.nTotal++
}
func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
tx := s.beginTxn(false)
func (s *StatsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
tx := beginTxn(s.database(), false)
if tx == nil {
return nil, 0
}
cur := s.ongoing()
curID := cur.id
var curID uint32
if cur != nil {
curID = atomic.LoadUint32(&cur.id)
} else {
curID = s.unitIDGen()
}
// Per-hour units.
units := []*unitDB{}
@ -574,14 +672,16 @@ func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
u := s.loadUnitFromDB(tx, i)
if u == nil {
u = &unitDB{}
u.NResult = make([]uint64, rLast)
u.NResult = make([]uint64, resultLast)
}
units = append(units, u)
}
_ = tx.Rollback()
units = append(units, serialize(cur))
if cur != nil {
units = append(units, cur.serialize())
}
if len(units) != int(limit) {
log.Fatalf("len(units) != limit: %d %d", len(units), limit)
@ -628,13 +728,13 @@ func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsG
// pairsGetter is a signature for topsCollector argument.
type pairsGetter func(u *unitDB) (pairs []countPair)
// topsCollector collects statistics about highest values fro the given *unitDB
// topsCollector collects statistics about highest values from the given *unitDB
// slice using pg to retrieve data.
func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 {
m := map[string]uint64{}
for _, u := range units {
for _, it := range pg(u) {
m[it.Name] += it.Count
for _, cp := range pg(u) {
m[cp.Name] += cp.Count
}
}
a2 := convertMapToSlice(m, max)
@ -668,8 +768,22 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64
* parental-blocked
These values are just the sum of data for all units.
*/
func (s *statsCtx) getData() (statsResponse, bool) {
limit := s.conf.limit
func (s *StatsCtx) getData() (statsResponse, bool) {
limit := atomic.LoadUint32(&s.limitHours)
if limit == 0 {
return statsResponse{
TimeUnits: "days",
TopBlocked: []topAddrs{},
TopClients: []topAddrs{},
TopQueried: []topAddrs{},
BlockedFiltering: []uint64{},
DNSQueries: []uint64{},
ReplacedParental: []uint64{},
ReplacedSafebrowsing: []uint64{},
}, true
}
timeUnit := Hours
if limit/24 > 7 {
@ -698,7 +812,7 @@ func (s *statsCtx) getData() (statsResponse, bool) {
// Total counters:
sum := unitDB{
NResult: make([]uint64, rLast),
NResult: make([]uint64, resultLast),
}
timeN := 0
for _, u := range units {
@ -731,12 +845,12 @@ func (s *statsCtx) getData() (statsResponse, bool) {
return data, true
}
func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
if s.conf.limit == 0 {
func (s *StatsCtx) GetTopClientsIP(maxCount uint) []net.IP {
if !s.isEnabled() {
return nil
}
units, _ := s.loadUnits(s.conf.limit)
units, _ := s.loadUnits(atomic.LoadUint32(&s.limitHours))
if units == nil {
return nil
}