diff --git a/internal/home/dns.go b/internal/home/dns.go index 0d44160d..572e22fe 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -397,7 +397,7 @@ func startDNSServer() error { Context.queryLog.Start() const topClientsNumber = 100 // the number of clients to get - for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { + for _, ip := range Context.stats.TopClientsIP(topClientsNumber) { if ip == nil { continue } @@ -456,7 +456,12 @@ func closeDNSServer() { } if Context.stats != nil { - Context.stats.Close() + err := Context.stats.Close() + if err != nil { + log.Debug("closing stats: %s", err) + } + + // TODO(e.burkov): Find out if it's safe. Context.stats = nil } diff --git a/internal/stats/http.go b/internal/stats/http.go index 033dd3bb..ae980bf3 100644 --- a/internal/stats/http.go +++ b/internal/stats/http.go @@ -5,6 +5,7 @@ package stats import ( "encoding/json" "net/http" + "sync/atomic" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" @@ -15,18 +16,10 @@ import ( // The key is either a client's address or a requested address. type topAddrs = map[string]uint64 -// statsResponse is a response for getting statistics. -type statsResponse struct { +// StatsResp is a response to the GET /control/stats. +type StatsResp struct { TimeUnits string `json:"time_units"` - NumDNSQueries uint64 `json:"num_dns_queries"` - NumBlockedFiltering uint64 `json:"num_blocked_filtering"` - NumReplacedSafebrowsing uint64 `json:"num_replaced_safebrowsing"` - NumReplacedSafesearch uint64 `json:"num_replaced_safesearch"` - NumReplacedParental uint64 `json:"num_replaced_parental"` - - AvgProcessingTime float64 `json:"avg_processing_time"` - TopQueried []topAddrs `json:"top_queried_domains"` TopClients []topAddrs `json:"top_clients"` TopBlocked []topAddrs `json:"top_blocked_domains"` @@ -36,16 +29,22 @@ type statsResponse struct { BlockedFiltering []uint64 `json:"blocked_filtering"` ReplacedSafebrowsing []uint64 `json:"replaced_safebrowsing"` ReplacedParental []uint64 `json:"replaced_parental"` + + NumDNSQueries uint64 `json:"num_dns_queries"` + NumBlockedFiltering uint64 `json:"num_blocked_filtering"` + NumReplacedSafebrowsing uint64 `json:"num_replaced_safebrowsing"` + NumReplacedSafesearch uint64 `json:"num_replaced_safesearch"` + NumReplacedParental uint64 `json:"num_replaced_parental"` + + AvgProcessingTime float64 `json:"avg_processing_time"` } -// handleStats is a handler for getting statistics. +// handleStats handles requests to the GET /control/stats endpoint. func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) { + limit := atomic.LoadUint32(&s.limitHours) + start := time.Now() - - var resp statsResponse - var ok bool - resp, ok = s.getData() - + resp, ok := s.getData(limit) log.Debug("stats: prepared data in %v", time.Since(start)) if !ok { @@ -61,36 +60,30 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) { err := json.NewEncoder(w).Encode(resp) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return } } -type config struct { +// configResp is the response to the GET /control/stats_info. +type configResp struct { IntervalDays uint32 `json:"interval"` } -// Get configuration +// handleStatsInfo handles requests to the GET /control/stats_info endpoint. func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { - resp := config{} - resp.IntervalDays = s.limitHours / 24 + resp := configResp{IntervalDays: atomic.LoadUint32(&s.limitHours) / 24} - data, err := json.Marshal(resp) + w.Header().Set("Content-Type", "application/json") + + err := json.NewEncoder(w).Encode(resp) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(data) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) } } -// Set configuration +// handleStatsConfig handles requests to the POST /control/stats_config +// endpoint. func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { - reqData := config{} + reqData := configResp{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) @@ -108,12 +101,15 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { s.configModified() } -// Reset data +// handleStatsReset handles requests to the POST /control/stats_reset endpoint. func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) { - s.clear() + err := s.clear() + if err != nil { + aghhttp.Error(r, w, http.StatusInternalServerError, "stats: %s", err) + } } -// Register web handlers +// initWeb registers the handlers for web endpoints of statistics module. func (s *StatsCtx) initWeb() { if s.httpRegister == nil { return diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 04a933d4..e483dbba 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -3,15 +3,20 @@ package stats import ( + "fmt" + "io" "net" + "os" + "sync" + "sync/atomic" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "go.etcd.io/bbolt" ) -// UnitIDGenFunc is the signature of a function that generates a unique ID for -// the statistics unit. -type UnitIDGenFunc func() (id uint32) - // DiskConfig is the configuration structure that is stored in file. type DiskConfig struct { // Interval is the number of days for which the statistics are collected @@ -19,6 +24,12 @@ type DiskConfig struct { Interval uint32 `yaml:"statistics_interval"` } +// checkInterval returns true if days is valid to be used as statistics +// retention interval. The valid values are 0, 1, 7, 30 and 90. +func checkInterval(days uint32) (ok bool) { + return days == 0 || days == 1 || days == 7 || days == 30 || days == 90 +} + // Config is the configuration structure for the statistics collecting. type Config struct { // UnitID is the function to generate the identifier for current unit. If @@ -46,58 +57,487 @@ type Interface interface { // Start begins the statistics collecting. Start() - // Close stops the statistics collecting. - Close() + io.Closer // Update collects the incoming statistics data. Update(e Entry) // GetTopClientIP returns at most limit IP addresses corresponding to the // clients with the most number of requests. - GetTopClientsIP(limit uint) []net.IP + TopClientsIP(limit uint) []net.IP // WriteDiskConfig puts the Interface's configuration to the dc. WriteDiskConfig(dc *DiskConfig) } -// TimeUnit is the unit of measuring time while aggregating the statistics. -type TimeUnit int - -// Supported TimeUnit values. -const ( - Hours TimeUnit = iota - Days -) - -// Result is the resulting code of processing the DNS request. -type Result int - -// Supported Result values. +// StatsCtx collects the statistics and flushes it to the database. Its default +// flushing interval is one hour. // -// TODO(e.burkov): Think about better naming. -const ( - RNotFiltered Result = iota + 1 - RFiltered - RSafeBrowsing - RSafeSearch - RParental - - resultLast = RParental + 1 -) - -// Entry is a statistics data entry. -type Entry struct { - // Clients is the client's primary ID. +// TODO(e.burkov): Use atomic.Pointer for accessing db in go1.19. +type StatsCtx struct { + // limitHours is the maximum number of hours to collect statistics into the + // current unit. // - // TODO(a.garipov): Make this a {net.IP, string} enum? - Client string + // It is of type uint32 to be accessed by atomic. It's arranged at the + // beginning of the structure to keep 64-bit alignment. + limitHours uint32 - // Domain is the domain name requested. - Domain string + // currMu protects curr. + currMu *sync.RWMutex + // curr is the actual statistics collection result. + curr *unit - // Result is the result of processing the request. - Result Result + // dbMu protects db. + dbMu *sync.Mutex + // db is the opened statistics database, if any. + db *bbolt.DB - // Time is the duration of the request processing in milliseconds. - Time uint32 + // unitIDGen is the function that generates an identifier for the current + // unit. It's here for only testing purposes. + unitIDGen UnitIDGenFunc + + // httpRegister is used to set HTTP handlers. + httpRegister aghhttp.RegisterFunc + + // configModified is called whenever the configuration is modified via web + // interface. + configModified func() + + // filename is the name of database file. + filename string +} + +var _ Interface = &StatsCtx{} + +// New creates s from conf and properly initializes it. Don't use s before +// calling it's Start method. +func New(conf Config) (s *StatsCtx, err error) { + defer withRecovered(&err) + + s = &StatsCtx{ + currMu: &sync.RWMutex{}, + dbMu: &sync.Mutex{}, + filename: conf.Filename, + configModified: conf.ConfigModified, + httpRegister: conf.HTTPRegister, + } + if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) { + s.limitHours = 24 + } + if s.unitIDGen = newUnitID; conf.UnitID != nil { + s.unitIDGen = conf.UnitID + } + + // TODO(e.burkov): Move the code below to the Start method. + + err = s.openDB() + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + + var udb *unitDB + id := s.unitIDGen() + + tx, err := s.db.Begin(true) + if err != nil { + return nil, fmt.Errorf("stats: opening a transaction: %w", err) + } + + deleted := deleteOldUnits(tx, id-s.limitHours-1) + udb = loadUnitFromDB(tx, id) + + err = finishTxn(tx, deleted > 0) + if err != nil { + log.Error("stats: %s", err) + } + + s.curr = newUnit(id) + s.curr.deserialize(udb) + + log.Debug("stats: initialized") + + return s, nil +} + +// withRecovered turns the value recovered from panic if any into an error and +// combines it with the one pointed by orig. orig must be non-nil. +func withRecovered(orig *error) { + p := recover() + if p == nil { + return + } + + var err error + switch p := p.(type) { + case error: + err = fmt.Errorf("panic: %w", p) + default: + err = fmt.Errorf("panic: recovered value of type %[1]T: %[1]v", p) + } + + *orig = errors.WithDeferred(*orig, err) +} + +// Start implements the Interface interface for *StatsCtx. +func (s *StatsCtx) Start() { + s.initWeb() + + go s.periodicFlush() +} + +// Close implements the io.Closer interface for *StatsCtx. +func (s *StatsCtx) Close() (err error) { + defer func() { err = errors.Annotate(err, "stats: closing: %w") }() + + db := s.swapDatabase(nil) + if db == nil { + return nil + } + defer func() { + cerr := db.Close() + if cerr == nil { + log.Debug("stats: database closed") + } + + err = errors.WithDeferred(err, cerr) + }() + + tx, err := db.Begin(true) + if err != nil { + return fmt.Errorf("opening transaction: %w", err) + } + defer func() { err = errors.WithDeferred(err, finishTxn(tx, err == nil)) }() + + s.currMu.RLock() + defer s.currMu.RUnlock() + + udb := s.curr.serialize() + + return udb.flushUnitToDB(tx, s.curr.id) +} + +// Update implements the Interface interface for *StatsCtx. +func (s *StatsCtx) Update(e Entry) { + if atomic.LoadUint32(&s.limitHours) == 0 { + return + } + + if e.Result == 0 || e.Result >= resultLast || e.Domain == "" || e.Client == "" { + log.Debug("stats: malformed entry") + + return + } + + s.currMu.Lock() + defer s.currMu.Unlock() + + if s.curr == nil { + log.Error("stats: current unit is nil") + + return + } + + clientID := e.Client + if ip := net.ParseIP(clientID); ip != nil { + clientID = ip.String() + } + + s.curr.add(e.Result, e.Domain, clientID, uint64(e.Time)) +} + +// WriteDiskConfig implements the Interface interface for *StatsCtx. +func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) { + dc.Interval = atomic.LoadUint32(&s.limitHours) / 24 +} + +// TopClientsIP implements the Interface interface for *StatsCtx. +func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) { + limit := atomic.LoadUint32(&s.limitHours) + if limit == 0 { + return nil + } + + units, _ := s.loadUnits(limit) + if units == nil { + return nil + } + + // Collect data for all the clients to sort and crop it afterwards. + m := map[string]uint64{} + for _, u := range units { + for _, it := range u.Clients { + m[it.Name] += it.Count + } + } + + a := convertMapToSlice(m, int(maxCount)) + ips = []net.IP{} + for _, it := range a { + ip := net.ParseIP(it.Name) + if ip != nil { + ips = append(ips, ip) + } + } + + return ips +} + +// database returns the database if it's opened. It's safe for concurrent use. +func (s *StatsCtx) database() (db *bbolt.DB) { + s.dbMu.Lock() + defer s.dbMu.Unlock() + + return s.db +} + +// swapDatabase swaps the database with another one and returns it. It's safe +// for concurrent use. +func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) { + s.dbMu.Lock() + defer s.dbMu.Unlock() + + old, s.db = s.db, with + + return old +} + +// deleteOldUnits walks the buckets available to tx and deletes old units. It +// returns the number of deletions performed. +func deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) { + log.Debug("stats: deleting old units until id %d", firstID) + + // TODO(a.garipov): See if this is actually necessary. Looks like a rather + // bizarre solution. + const errStop errors.Error = "stop iteration" + + walk := func(name []byte, _ *bbolt.Bucket) (err error) { + nameID, ok := unitNameToID(name) + if ok && nameID >= firstID { + return errStop + } + + err = tx.DeleteBucket(name) + if err != nil { + log.Debug("stats: deleting bucket: %s", err) + + return nil + } + + log.Debug("stats: deleted unit %d (name %x)", nameID, name) + + deleted++ + + return nil + } + + err := tx.ForEach(walk) + if err != nil && !errors.Is(err, errStop) { + log.Debug("stats: deleting units: %s", err) + } + + return deleted +} + +// openDB returns an error if the database can't be opened from the specified +// file. It's safe for concurrent use. +func (s *StatsCtx) openDB() (err error) { + log.Debug("stats: opening database") + + var db *bbolt.DB + db, err = bbolt.Open(s.filename, 0o644, nil) + if err != nil { + if err.Error() == "invalid argument" { + log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations") + } + + return err + } + + // Use defer to unlock the mutex as soon as possible. + defer log.Debug("stats: database opened") + + s.dbMu.Lock() + defer s.dbMu.Unlock() + + s.db = db + + return nil +} + +func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) { + id := s.unitIDGen() + + s.currMu.Lock() + defer s.currMu.Unlock() + + ptr := s.curr + if ptr == nil { + return false, 0 + } + + limit := atomic.LoadUint32(&s.limitHours) + if limit == 0 || ptr.id == id { + return true, time.Second + } + + db := s.database() + if db == nil { + return true, 0 + } + + tx, err := db.Begin(true) + if err != nil { + log.Error("stats: opening transaction: %s", err) + + return true, 0 + } + + s.curr = newUnit(id) + isCommitable := true + + ferr := ptr.serialize().flushUnitToDB(tx, ptr.id) + if ferr != nil { + log.Error("stats: flushing unit: %s", ferr) + isCommitable = false + } + + derr := tx.DeleteBucket(idToUnitName(id - limit)) + if derr != nil { + log.Error("stats: deleting unit: %s", derr) + if !errors.Is(derr, bbolt.ErrBucketNotFound) { + isCommitable = false + } + } + + err = finishTxn(tx, isCommitable) + if err != nil { + log.Error("stats: %s", err) + } + + return true, 0 +} + +// periodicFlush checks and flushes the unit to the database if the freshly +// generated unit ID differs from the current's ID. Flushing process includes: +// - swapping the current unit with the new empty one; +// - writing the current unit to the database; +// - removing the stale unit from the database. +func (s *StatsCtx) periodicFlush() { + for cont, sleepFor := true, time.Duration(0); cont; time.Sleep(sleepFor) { + cont, sleepFor = s.flush() + } + + log.Debug("periodic flushing finished") +} + +func (s *StatsCtx) setLimit(limitDays int) { + atomic.StoreUint32(&s.limitHours, uint32(24*limitDays)) + if limitDays == 0 { + if err := s.clear(); err != nil { + log.Error("stats: %s", err) + } + } + + log.Debug("stats: set limit: %d days", limitDays) +} + +// Reset counters and clear database +func (s *StatsCtx) clear() (err error) { + defer func() { err = errors.Annotate(err, "clearing: %w") }() + + db := s.swapDatabase(nil) + if db != nil { + var tx *bbolt.Tx + tx, err = db.Begin(true) + if err != nil { + log.Error("stats: opening a transaction: %s", err) + } else if err = finishTxn(tx, false); err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + // Active transactions will continue using database, but new ones won't + // be created. + err = db.Close() + if err != nil { + return fmt.Errorf("closing database: %w", err) + } + + // All active transactions are now closed. + log.Debug("stats: database closed") + } + + err = os.Remove(s.filename) + if err != nil { + log.Error("stats: %s", err) + } + + err = s.openDB() + if err != nil { + log.Error("stats: opening database: %s", err) + } + + // Use defer to unlock the mutex as soon as possible. + defer log.Debug("stats: cleared") + + s.currMu.Lock() + defer s.currMu.Unlock() + + s.curr = newUnit(s.unitIDGen()) + + return nil +} + +func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, firstID uint32) { + db := s.database() + if db == nil { + return nil, 0 + } + + // Use writable transaction to ensure any ongoing writable transaction is + // taken into account. + tx, err := db.Begin(true) + if err != nil { + log.Error("stats: opening transaction: %s", err) + + return nil, 0 + } + + s.currMu.RLock() + defer s.currMu.RUnlock() + + cur := s.curr + + var curID uint32 + if cur != nil { + curID = cur.id + } else { + curID = s.unitIDGen() + } + + // Per-hour units. + units = make([]*unitDB, 0, limit) + firstID = curID - limit + 1 + for i := firstID; i != curID; i++ { + u := loadUnitFromDB(tx, i) + if u == nil { + u = &unitDB{NResult: make([]uint64, resultLast)} + } + units = append(units, u) + } + + err = finishTxn(tx, false) + if err != nil { + log.Error("stats: %s", err) + } + + if cur != nil { + units = append(units, cur.serialize()) + } + + if unitsLen := len(units); unitsLen != int(limit) { + log.Fatalf("loaded %d units whilst the desired number is %d", unitsLen, limit) + } + + return units, firstID } diff --git a/internal/stats/stats_internal_test.go b/internal/stats/stats_internal_test.go new file mode 100644 index 00000000..28a556d3 --- /dev/null +++ b/internal/stats/stats_internal_test.go @@ -0,0 +1,26 @@ +package stats + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO(e.burkov): Use more realistic data. +func TestStatsCollector(t *testing.T) { + ng := func(_ *unitDB) uint64 { return 0 } + units := make([]*unitDB, 720) + + t.Run("hours", func(t *testing.T) { + statsData := statsCollector(units, 0, Hours, ng) + assert.Len(t, statsData, 720) + }) + + t.Run("days", func(t *testing.T) { + for i := 0; i != 25; i++ { + statsData := statsCollector(units, uint32(i), Days, ng) + require.Lenf(t, statsData, 30, "i=%d", i) + } + }) +} diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 0cffd2e3..5d86024b 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -1,13 +1,17 @@ -package stats +package stats_test import ( + "encoding/json" "fmt" "net" - "os" + "net/http" + "net/http/httptest" + "path/filepath" "sync/atomic" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,147 +21,176 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -func UIntArrayEquals(a, b []uint64) bool { - if len(a) != len(b) { - return false +// constUnitID is the UnitIDGenFunc which always return 0. +func constUnitID() (id uint32) { return 0 } + +func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req *http.Request) { + t.Helper() + + require.NotNil(t, handler) + + rw := httptest.NewRecorder() + + handler.ServeHTTP(rw, req) + require.Equal(t, http.StatusOK, rw.Code) + + data := rw.Body.Bytes() + if to == nil { + assert.Empty(t, data) + + return } - for i := range a { - if a[i] != b[i] { - return false - } - } - - return true + err := json.Unmarshal(data, to) + require.NoError(t, err) } func TestStats(t *testing.T) { - conf := Config{ - Filename: "./stats.db", + cliIP := net.IP{127, 0, 0, 1} + cliIPStr := cliIP.String() + + handlers := map[string]http.Handler{} + conf := stats.Config{ + Filename: filepath.Join(t.TempDir(), "stats.db"), LimitDays: 1, + UnitID: constUnitID, + HTTPRegister: func(_, url string, handler http.HandlerFunc) { + handlers[url] = handler + }, } - s, err := New(conf) + s, err := stats.New(conf) require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { - s.clear() - s.Close() - return os.Remove(conf.Filename) + s.Start() + testutil.CleanupAndRequireSuccess(t, s.Close) + + t.Run("data", func(t *testing.T) { + const reqDomain = "domain" + + entries := []stats.Entry{{ + Domain: reqDomain, + Client: cliIPStr, + Result: stats.RFiltered, + Time: 123456, + }, { + Domain: reqDomain, + Client: cliIPStr, + Result: stats.RNotFiltered, + Time: 123456, + }} + + wantData := &stats.StatsResp{ + TimeUnits: "hours", + TopQueried: []map[string]uint64{0: {reqDomain: 1}}, + TopClients: []map[string]uint64{0: {cliIPStr: 2}}, + TopBlocked: []map[string]uint64{0: {reqDomain: 1}}, + DNSQueries: []uint64{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + }, + BlockedFiltering: []uint64{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ReplacedSafebrowsing: []uint64{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ReplacedParental: []uint64{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + NumDNSQueries: 2, + NumBlockedFiltering: 1, + NumReplacedSafebrowsing: 0, + NumReplacedSafesearch: 0, + NumReplacedParental: 0, + AvgProcessingTime: 0.123456, + } + + for _, e := range entries { + s.Update(e) + } + + data := &stats.StatsResp{} + req := httptest.NewRequest(http.MethodGet, "/control/stats", nil) + assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req) + + assert.Equal(t, wantData, data) }) - s.Update(Entry{ - Domain: "domain", - Client: "127.0.0.1", - Result: RFiltered, - Time: 123456, - }) - s.Update(Entry{ - Domain: "domain", - Client: "127.0.0.1", - Result: RNotFiltered, - Time: 123456, + t.Run("tops", func(t *testing.T) { + topClients := s.TopClientsIP(2) + require.NotEmpty(t, topClients) + + assert.True(t, cliIP.Equal(topClients[0])) }) - d, ok := s.getData() - require.True(t, ok) + t.Run("reset", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/control/stats_reset", nil) + assertSuccessAndUnmarshal(t, nil, handlers["/control/stats_reset"], req) - a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} - assert.True(t, UIntArrayEquals(d.DNSQueries, a)) + _24zeroes := [24]uint64{} + emptyData := &stats.StatsResp{ + TimeUnits: "hours", + TopQueried: []map[string]uint64{}, + TopClients: []map[string]uint64{}, + TopBlocked: []map[string]uint64{}, + DNSQueries: _24zeroes[:], + BlockedFiltering: _24zeroes[:], + ReplacedSafebrowsing: _24zeroes[:], + ReplacedParental: _24zeroes[:], + } - a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - assert.True(t, UIntArrayEquals(d.BlockedFiltering, a)) + req = httptest.NewRequest(http.MethodGet, "/control/stats", nil) + data := &stats.StatsResp{} - a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - assert.True(t, UIntArrayEquals(d.ReplacedSafebrowsing, a)) - - a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - assert.True(t, UIntArrayEquals(d.ReplacedParental, a)) - - m := d.TopQueried - require.NotEmpty(t, m) - assert.EqualValues(t, 1, m[0]["domain"]) - - m = d.TopBlocked - require.NotEmpty(t, m) - assert.EqualValues(t, 1, m[0]["domain"]) - - m = d.TopClients - require.NotEmpty(t, m) - assert.EqualValues(t, 2, m[0]["127.0.0.1"]) - - assert.EqualValues(t, 2, d.NumDNSQueries) - assert.EqualValues(t, 1, d.NumBlockedFiltering) - assert.EqualValues(t, 0, d.NumReplacedSafebrowsing) - assert.EqualValues(t, 0, d.NumReplacedSafesearch) - assert.EqualValues(t, 0, d.NumReplacedParental) - assert.EqualValues(t, 0.123456, d.AvgProcessingTime) - - topClients := s.GetTopClientsIP(2) - require.NotEmpty(t, topClients) - assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0])) + assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req) + assert.Equal(t, emptyData, data) + }) } func TestLargeNumbers(t *testing.T) { - var hour int32 = 0 - newID := func() uint32 { - // Use "atomic" to make go race detector happy. - return uint32(atomic.LoadInt32(&hour)) + var curHour uint32 = 1 + handlers := map[string]http.Handler{} + + conf := stats.Config{ + Filename: filepath.Join(t.TempDir(), "stats.db"), + LimitDays: 1, + UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) }, + HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler }, } - conf := Config{ - Filename: "./stats.db", - LimitDays: 1, - UnitID: newID, - } - s, err := New(conf) + s, err := stats.New(conf) require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { - s.Close() - return os.Remove(conf.Filename) - }) + s.Start() + testutil.CleanupAndRequireSuccess(t, s.Close) - // Number of distinct clients and domains every hour. - const n = 1000 + const ( + hoursNum = 12 + cliNumPerHour = 1000 + ) - for h := 0; h < 12; h++ { - atomic.AddInt32(&hour, 1) - for i := 0; i < n; i++ { - s.Update(Entry{ - Domain: fmt.Sprintf("domain%d", i), - Client: net.IP{ - 127, - 0, - byte((i & 0xff00) >> 8), - byte(i & 0xff), - }.String(), - Result: RNotFiltered, + req := httptest.NewRequest(http.MethodGet, "/control/stats", nil) + + for h := 0; h < hoursNum; h++ { + atomic.AddUint32(&curHour, 1) + + for i := 0; i < cliNumPerHour; i++ { + ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)} + e := stats.Entry{ + Domain: fmt.Sprintf("domain%d.hour%d", i, h), + Client: ip.String(), + Result: stats.RNotFiltered, Time: 123456, - }) + } + s.Update(e) } } - d, ok := s.getData() - require.True(t, ok) - assert.EqualValues(t, hour*n, d.NumDNSQueries) -} - -func TestStatsCollector(t *testing.T) { - ng := func(_ *unitDB) uint64 { - return 0 - } - units := make([]*unitDB, 720) - - t.Run("hours", func(t *testing.T) { - statsData := statsCollector(units, 0, Hours, ng) - assert.Len(t, statsData, 720) - }) - - t.Run("days", func(t *testing.T) { - for i := 0; i != 25; i++ { - statsData := statsCollector(units, uint32(i), Days, ng) - require.Lenf(t, statsData, 30, "i=%d", i) - } - }) + data := &stats.StatsResp{} + assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req) + assert.Equal(t, hoursNum*cliNumPerHour, int(data.NumDNSQueries)) } diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 6d32a6d1..28e0b2bc 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -5,14 +5,9 @@ import ( "encoding/binary" "encoding/gob" "fmt" - "net" - "os" "sort" - "sync" - "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "go.etcd.io/bbolt" @@ -22,51 +17,65 @@ import ( // inspection. Improve logging. Decrease complexity. const ( - maxDomains = 100 // max number of top domains to store in file or return via Get() - maxClients = 100 // max number of top clients to store in file or return via Get() + // maxDomains is the max number of top domains to return. + maxDomains = 100 + // maxClients is the max number of top clients to return. + maxClients = 100 ) -// StatsCtx collects the statistics and flushes it to the database. Its default -// flushing interval is one hour. +// UnitIDGenFunc is the signature of a function that generates a unique ID for +// the statistics unit. +type UnitIDGenFunc func() (id uint32) + +// TimeUnit is the unit of measuring time while aggregating the statistics. +type TimeUnit int + +// Supported TimeUnit values. +const ( + Hours TimeUnit = iota + Days +) + +// Result is the resulting code of processing the DNS request. +type Result int + +// Supported Result values. // -// TODO(e.burkov): Use atomic.Pointer for accessing curr and db in go1.19. -type StatsCtx struct { - // currMu protects the current unit. - currMu *sync.Mutex - // curr is the actual statistics collection result. - curr *unit +// TODO(e.burkov): Think about better naming. +const ( + RNotFiltered Result = iota + 1 + RFiltered + RSafeBrowsing + RSafeSearch + RParental - // dbMu protects db. - dbMu *sync.Mutex - // db is the opened statistics database, if any. - db *bbolt.DB + resultLast = RParental + 1 +) - // unitIDGen is the function that generates an identifier for the current - // unit. It's here for only testing purposes. - unitIDGen UnitIDGenFunc +// Entry is a statistics data entry. +type Entry struct { + // Clients is the client's primary ID. + // + // TODO(a.garipov): Make this a {net.IP, string} enum? + Client string - // httpRegister is used to set HTTP handlers. - httpRegister aghhttp.RegisterFunc + // Domain is the domain name requested. + Domain string - // configModified is called whenever the configuration is modified via web - // interface. - configModified func() + // Result is the result of processing the request. + Result Result - // filename is the name of database file. - filename string - - // limitHours is the maximum number of hours to collect statistics into the - // current unit. - limitHours uint32 + // Time is the duration of the request processing in milliseconds. + Time uint32 } // unit collects the statistics data for a specific period of time. type unit struct { - // mu protects all the fields of a unit. - mu *sync.RWMutex - // id is the unique unit's identifier. It's set to an absolute hour number // since the beginning of UNIX time by the default ID generating function. + // + // Must not be rewritten after creating to be accessed concurrently without + // using mu. id uint32 // nTotal stores the total number of requests. @@ -86,44 +95,15 @@ type unit struct { clients map[string]uint64 } -// ongoing returns the current unit. It's safe for concurrent use. -// -// Note that the unit itself should be locked before accessing. -func (s *StatsCtx) ongoing() (u *unit) { - s.currMu.Lock() - defer s.currMu.Unlock() - - return s.curr -} - -// swapCurrent swaps the current unit with another and returns it. It's safe -// for concurrent use. -func (s *StatsCtx) swapCurrent(with *unit) (old *unit) { - s.currMu.Lock() - defer s.currMu.Unlock() - - old, s.curr = s.curr, with - - return old -} - -// database returns the database if it's opened. It's safe for concurrent use. -func (s *StatsCtx) database() (db *bbolt.DB) { - s.dbMu.Lock() - defer s.dbMu.Unlock() - - return s.db -} - -// swapDatabase swaps the database with another one and returns it. It's safe -// for concurrent use. -func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) { - s.dbMu.Lock() - defer s.dbMu.Unlock() - - old, s.db = s.db, with - - return old +// newUnit allocates the new *unit. +func newUnit(id uint32) (u *unit) { + return &unit{ + id: id, + nResult: make([]uint64, resultLast), + domains: make(map[string]uint64), + blockedDomains: make(map[string]uint64), + clients: make(map[string]uint64), + } } // countPair is a single name-number pair for deserializing statistics data into @@ -133,7 +113,7 @@ type countPair struct { Count uint64 } -// unitDB is the structure for deserializing statistics data into the database. +// unitDB is the structure for serializing statistics data into the database. type unitDB struct { // NTotal is the total number of requests. NTotal uint64 @@ -152,157 +132,6 @@ type unitDB struct { TimeAvg uint32 } -// withRecovered turns the value recovered from panic if any into an error and -// combines it with the one pointed by orig. orig must be non-nil. -func withRecovered(orig *error) { - p := recover() - if p == nil { - return - } - - var err error - switch p := p.(type) { - case error: - err = fmt.Errorf("panic: %w", p) - default: - err = fmt.Errorf("panic: recovered value of type %[1]T: %[1]v", p) - } - - *orig = errors.WithDeferred(*orig, err) -} - -// isEnabled is a helper that check if the statistics collecting is enabled. -func (s *StatsCtx) isEnabled() (ok bool) { - return atomic.LoadUint32(&s.limitHours) != 0 -} - -// New creates s from conf and properly initializes it. Don't use s before -// calling it's Start method. -func New(conf Config) (s *StatsCtx, err error) { - defer withRecovered(&err) - - s = &StatsCtx{ - currMu: &sync.Mutex{}, - dbMu: &sync.Mutex{}, - filename: conf.Filename, - configModified: conf.ConfigModified, - httpRegister: conf.HTTPRegister, - } - if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) { - s.limitHours = 24 - } - if s.unitIDGen = newUnitID; conf.UnitID != nil { - s.unitIDGen = conf.UnitID - } - - if err = s.dbOpen(); err != nil { - return nil, fmt.Errorf("opening database: %w", err) - } - - id := s.unitIDGen() - tx := beginTxn(s.db, true) - var udb *unitDB - if tx != nil { - log.Tracef("Deleting old units...") - firstID := id - s.limitHours - 1 - unitDel := 0 - - err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID)) - if err != nil && !errors.Is(err, errStop) { - log.Debug("stats: deleting units: %s", err) - } - - udb = s.loadUnitFromDB(tx, id) - - if unitDel != 0 { - s.commitTxn(tx) - } else { - err = tx.Rollback() - if err != nil { - log.Debug("rolling back: %s", err) - } - } - } - - u := newUnit(id) - // This use of deserialize is safe since the accessed unit has just been - // created. - u.deserialize(udb) - s.curr = u - - log.Debug("stats: initialized") - - return s, nil -} - -// TODO(a.garipov): See if this is actually necessary. Looks like a rather -// bizarre solution. -const errStop errors.Error = "stop iteration" - -// newBucketWalker returns a new bucket walker that deletes old units. The -// integer that unitDelPtr points to is incremented for every successful -// deletion. If the bucket isn't deleted, f returns errStop. -func newBucketWalker( - tx *bbolt.Tx, - unitDelPtr *int, - firstID uint32, -) (f func(name []byte, b *bbolt.Bucket) (err error)) { - return func(name []byte, _ *bbolt.Bucket) (err error) { - nameID, ok := unitNameToID(name) - if !ok || nameID < firstID { - err = tx.DeleteBucket(name) - if err != nil { - log.Debug("stats: tx.DeleteBucket: %s", err) - - return nil - } - - log.Debug("stats: deleted unit %d (name %x)", nameID, name) - - *unitDelPtr++ - - return nil - } - - return errStop - } -} - -// Start makes s process the incoming data. -func (s *StatsCtx) Start() { - s.initWeb() - go s.periodicFlush() -} - -// checkInterval returns true if days is valid to be used as statistics -// retention interval. The valid values are 0, 1, 7, 30 and 90. -func checkInterval(days uint32) (ok bool) { - return days == 0 || days == 1 || days == 7 || days == 30 || days == 90 -} - -// dbOpen returns an error if the database can't be opened from the specified -// file. It's safe for concurrent use. -func (s *StatsCtx) dbOpen() (err error) { - log.Tracef("db.Open...") - - s.dbMu.Lock() - defer s.dbMu.Unlock() - - s.db, err = bbolt.Open(s.filename, 0o644, nil) - if err != nil { - log.Error("stats: open DB: %s: %s", s.filename, err) - if err.Error() == "invalid argument" { - log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations") - } - - return err - } - - log.Tracef("db.Open") - - return nil -} - // newUnitID is the default UnitIDGenFunc that generates the unique id hourly. func newUnitID() (id uint32) { const secsInHour = int64(time.Hour / time.Second) @@ -310,50 +139,14 @@ func newUnitID() (id uint32) { return uint32(time.Now().Unix() / secsInHour) } -// newUnit allocates the new *unit. -func newUnit(id uint32) (u *unit) { - return &unit{ - mu: &sync.RWMutex{}, - id: id, - nResult: make([]uint64, resultLast), - domains: make(map[string]uint64), - blockedDomains: make(map[string]uint64), - clients: make(map[string]uint64), - } -} - -// beginTxn opens a new database transaction. If writable is true, the -// transaction will be opened for writing, and for reading otherwise. It -// returns nil if the transaction can't be created. -func beginTxn(db *bbolt.DB, writable bool) (tx *bbolt.Tx) { - if db == nil { - return nil +func finishTxn(tx *bbolt.Tx, commit bool) (err error) { + if commit { + err = errors.Annotate(tx.Commit(), "committing: %w") + } else { + err = errors.Annotate(tx.Rollback(), "rolling back: %w") } - log.Tracef("opening a database transaction") - - tx, err := db.Begin(writable) - if err != nil { - log.Error("stats: opening a transaction: %s", err) - - return nil - } - - log.Tracef("transaction has been opened") - - return tx -} - -// commitTxn applies the changes made in tx to the database. -func (s *StatsCtx) commitTxn(tx *bbolt.Tx) { - err := tx.Commit() - if err != nil { - log.Error("stats: committing a transaction: %s", err) - - return - } - - log.Tracef("transaction has been committed") + return err } // bucketNameLen is the length of a bucket, a 64-bit unsigned integer. @@ -380,88 +173,34 @@ func unitNameToID(name []byte) (id uint32, ok bool) { return uint32(binary.BigEndian.Uint64(name)), true } -// Flush the current unit to DB and delete an old unit when a new hour is started -// If a unit must be flushed: -// . lock DB -// . atomically set a new empty unit as the current one and get the old unit -// This is important to do it inside DB lock, so the reader won't get inconsistent results. -// . write the unit to DB -// . remove the stale unit from DB -// . unlock DB -func (s *StatsCtx) periodicFlush() { - for ptr := s.ongoing(); ptr != nil; ptr = s.ongoing() { - id := s.unitIDGen() - // Access the unit's ID with atomic to avoid locking the whole unit. - if !s.isEnabled() || atomic.LoadUint32(&ptr.id) == id { - time.Sleep(time.Second) - - continue - } - - tx := beginTxn(s.database(), true) - - nu := newUnit(id) - u := s.swapCurrent(nu) - udb := u.serialize() - - if tx == nil { - continue - } - - flushOK := flushUnitToDB(tx, u.id, udb) - delOK := s.deleteUnit(tx, id-atomic.LoadUint32(&s.limitHours)) - if flushOK || delOK { - s.commitTxn(tx) - } else { - _ = tx.Rollback() - } - } - - log.Tracef("periodicFlush() exited") -} - -// deleteUnit removes the unit by it's id from the database the tx belongs to. -func (s *StatsCtx) deleteUnit(tx *bbolt.Tx, id uint32) bool { - err := tx.DeleteBucket(idToUnitName(id)) - if err != nil { - log.Tracef("stats: bolt DeleteBucket: %s", err) - - return false - } - - log.Debug("stats: deleted unit %d", id) - - return true -} - -func convertMapToSlice(m map[string]uint64, max int) []countPair { - a := []countPair{} +func convertMapToSlice(m map[string]uint64, max int) (s []countPair) { + s = make([]countPair, 0, len(m)) for k, v := range m { - a = append(a, countPair{Name: k, Count: v}) + s = append(s, countPair{Name: k, Count: v}) } - less := func(i, j int) bool { - return a[j].Count < a[i].Count + + sort.Slice(s, func(i, j int) bool { + return s[j].Count < s[i].Count + }) + if max > len(s) { + max = len(s) } - sort.Slice(a, less) - if max > len(a) { - max = len(a) - } - return a[:max] + + return s[:max] } -func convertSliceToMap(a []countPair) map[string]uint64 { - m := map[string]uint64{} +func convertSliceToMap(a []countPair) (m map[string]uint64) { + m = map[string]uint64{} for _, it := range a { m[it.Name] = it.Count } + return m } -// serialize converts u to the *unitDB. It's safe for concurrent use. +// serialize converts u to the *unitDB. It's safe for concurrent use. u must +// not be nil. func (u *unit) serialize() (udb *unitDB) { - u.mu.RLock() - defer u.mu.RUnlock() - var timeAvg uint32 = 0 if u.nTotal != 0 { timeAvg = uint32(u.timeSum / u.nTotal) @@ -477,6 +216,28 @@ func (u *unit) serialize() (udb *unitDB) { } } +func loadUnitFromDB(tx *bbolt.Tx, id uint32) (udb *unitDB) { + bkt := tx.Bucket(idToUnitName(id)) + if bkt == nil { + return nil + } + + log.Tracef("Loading unit %d", id) + + var buf bytes.Buffer + buf.Write(bkt.Get([]byte{0})) + udb = &unitDB{} + + err := gob.NewDecoder(&buf).Decode(udb) + if err != nil { + log.Error("gob Decode: %s", err) + + return nil + } + + return udb +} + // deserealize assigns the appropriate values from udb to u. u must not be nil. // It's safe for concurrent use. func (u *unit) deserialize(udb *unitDB) { @@ -484,9 +245,6 @@ func (u *unit) deserialize(udb *unitDB) { return } - u.mu.Lock() - defer u.mu.Unlock() - u.nTotal = udb.NTotal u.nResult = make([]uint64, resultLast) copy(u.nResult, udb.NResult) @@ -496,51 +254,41 @@ func (u *unit) deserialize(udb *unitDB) { u.timeSum = uint64(udb.TimeAvg) * udb.NTotal } -func flushUnitToDB(tx *bbolt.Tx, id uint32, udb *unitDB) bool { - log.Tracef("Flushing unit %d", id) +// add adds new data to u. It's safe for concurrent use. +func (u *unit) add(res Result, domain, cli string, dur uint64) { + u.nResult[res]++ + if res == RNotFiltered { + u.domains[domain]++ + } else { + u.blockedDomains[domain]++ + } + + u.clients[cli]++ + u.timeSum += dur + u.nTotal++ +} + +// flushUnitToDB puts udb to the database at id. +func (udb *unitDB) flushUnitToDB(tx *bbolt.Tx, id uint32) (err error) { + log.Debug("stats: flushing unit with id %d and total of %d", id, udb.NTotal) bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id)) if err != nil { - log.Error("tx.CreateBucketIfNotExists: %s", err) - return false + return fmt.Errorf("creating bucket: %w", err) } - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - err = enc.Encode(udb) + buf := &bytes.Buffer{} + err = gob.NewEncoder(buf).Encode(udb) if err != nil { - log.Error("gob.Encode: %s", err) - return false + return fmt.Errorf("encoding unit: %w", err) } err = bkt.Put([]byte{0}, buf.Bytes()) if err != nil { - log.Error("bkt.Put: %s", err) - return false + return fmt.Errorf("putting unit to database: %w", err) } - return true -} - -func (s *StatsCtx) loadUnitFromDB(tx *bbolt.Tx, id uint32) *unitDB { - bkt := tx.Bucket(idToUnitName(id)) - if bkt == nil { - return nil - } - - // log.Tracef("Loading unit %d", id) - - var buf bytes.Buffer - buf.Write(bkt.Get([]byte{0})) - dec := gob.NewDecoder(&buf) - udb := unitDB{} - err := dec.Decode(&udb) - if err != nil { - log.Error("gob Decode: %s", err) - return nil - } - - return &udb + return nil } func convertTopSlice(a []countPair) (m []map[string]uint64) { @@ -552,144 +300,6 @@ func convertTopSlice(a []countPair) (m []map[string]uint64) { return m } -func (s *StatsCtx) setLimit(limitDays int) { - atomic.StoreUint32(&s.limitHours, uint32(24*limitDays)) - if limitDays == 0 { - s.clear() - } - - log.Debug("stats: set limit: %d days", limitDays) -} - -func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) { - dc.Interval = atomic.LoadUint32(&s.limitHours) / 24 -} - -func (s *StatsCtx) Close() { - u := s.swapCurrent(nil) - - db := s.database() - if tx := beginTxn(db, true); tx != nil { - udb := u.serialize() - if flushUnitToDB(tx, u.id, udb) { - s.commitTxn(tx) - } else { - _ = tx.Rollback() - } - } - - if db != nil { - log.Tracef("db.Close...") - _ = db.Close() - log.Tracef("db.Close") - } - - log.Debug("stats: closed") -} - -// Reset counters and clear database -func (s *StatsCtx) clear() { - db := s.database() - tx := beginTxn(db, true) - if tx != nil { - _ = s.swapDatabase(nil) - _ = tx.Rollback() - // the active transactions can continue using database, - // but no new transactions will be opened - _ = db.Close() - log.Tracef("db.Close") - // all active transactions are now closed - } - - u := newUnit(s.unitIDGen()) - _ = s.swapCurrent(u) - - err := os.Remove(s.filename) - if err != nil { - log.Error("os.Remove: %s", err) - } - - _ = s.dbOpen() - - log.Debug("stats: cleared") -} - -func (s *StatsCtx) Update(e Entry) { - if !s.isEnabled() { - return - } - - if e.Result == 0 || - e.Result >= resultLast || - e.Domain == "" || - e.Client == "" { - return - } - - clientID := e.Client - if ip := net.ParseIP(clientID); ip != nil { - clientID = ip.String() - } - - u := s.ongoing() - if u == nil { - return - } - - u.mu.Lock() - defer u.mu.Unlock() - - u.nResult[e.Result]++ - if e.Result == RNotFiltered { - u.domains[e.Domain]++ - } else { - u.blockedDomains[e.Domain]++ - } - - u.clients[clientID]++ - u.timeSum += uint64(e.Time) - u.nTotal++ -} - -func (s *StatsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { - tx := beginTxn(s.database(), false) - if tx == nil { - return nil, 0 - } - - cur := s.ongoing() - var curID uint32 - if cur != nil { - curID = atomic.LoadUint32(&cur.id) - } else { - curID = s.unitIDGen() - } - - // Per-hour units. - units := []*unitDB{} - firstID := curID - limit + 1 - for i := firstID; i != curID; i++ { - u := s.loadUnitFromDB(tx, i) - if u == nil { - u = &unitDB{} - u.NResult = make([]uint64, resultLast) - } - units = append(units, u) - } - - _ = tx.Rollback() - - if cur != nil { - units = append(units, cur.serialize()) - } - - if len(units) != int(limit) { - log.Fatalf("len(units) != limit: %d %d", len(units), limit) - } - - return units, firstID -} - // numsGetter is a signature for statsCollector argument. type numsGetter func(u *unitDB) (num uint64) @@ -697,6 +307,7 @@ type numsGetter func(u *unitDB) (num uint64) // timeUnit using ng to retrieve data. func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsGetter) (nums []uint64) { if timeUnit == Hours { + nums = make([]uint64, 0, len(units)) for _, u := range units { nums = append(nums, ng(u)) } @@ -738,6 +349,7 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 } } a2 := convertMapToSlice(m, max) + return convertTopSlice(a2) } @@ -768,10 +380,9 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 * parental-blocked These values are just the sum of data for all units. */ -func (s *StatsCtx) getData() (statsResponse, bool) { - limit := atomic.LoadUint32(&s.limitHours) +func (s *StatsCtx) getData(limit uint32) (StatsResp, bool) { if limit == 0 { - return statsResponse{ + return StatsResp{ TimeUnits: "days", TopBlocked: []topAddrs{}, @@ -792,7 +403,7 @@ func (s *StatsCtx) getData() (statsResponse, bool) { units, firstID := s.loadUnits(limit) if units == nil { - return statsResponse{}, false + return StatsResp{}, false } dnsQueries := statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NTotal }) @@ -800,7 +411,7 @@ func (s *StatsCtx) getData() (statsResponse, bool) { log.Fatalf("len(dnsQueries) != limit: %d %d", len(dnsQueries), limit) } - data := statsResponse{ + data := StatsResp{ DNSQueries: dnsQueries, BlockedFiltering: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RFiltered] }), ReplacedSafebrowsing: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RSafeBrowsing] }), @@ -844,31 +455,3 @@ func (s *StatsCtx) getData() (statsResponse, bool) { return data, true } - -func (s *StatsCtx) GetTopClientsIP(maxCount uint) []net.IP { - if !s.isEnabled() { - return nil - } - - units, _ := s.loadUnits(atomic.LoadUint32(&s.limitHours)) - if units == nil { - return nil - } - - // top clients - m := map[string]uint64{} - for _, u := range units { - for _, it := range u.Clients { - m[it.Name] += it.Count - } - } - a := convertMapToSlice(m, int(maxCount)) - d := []net.IP{} - for _, it := range a { - ip := net.ParseIP(it.Name) - if ip != nil { - d = append(d, ip) - } - } - return d -}