package stats import ( "bytes" "encoding/json" "net/http" "net/http/httptest" "path/filepath" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHandleStatsConfig(t *testing.T) { const ( smallIvl = 1 * time.Minute minIvl = 1 * time.Hour maxIvl = 365 * timeutil.Day ) conf := Config{ UnitID: func() (id uint32) { return 0 }, ConfigModified: func() {}, ShouldCountClient: func([]string) bool { return true }, Filename: filepath.Join(t.TempDir(), "stats.db"), Limit: time.Hour * 24, Enabled: true, } testCases := []struct { name string wantErr string body getConfigResp wantCode int }{{ name: "set_ivl_1_minIvl", body: getConfigResp{ Enabled: aghalg.NBTrue, Interval: float64(minIvl.Milliseconds()), Ignored: []string{}, }, wantCode: http.StatusOK, wantErr: "", }, { name: "small_interval", body: getConfigResp{ Enabled: aghalg.NBTrue, Interval: float64(smallIvl.Milliseconds()), Ignored: []string{}, }, wantCode: http.StatusUnprocessableEntity, wantErr: "unsupported interval: less than an hour\n", }, { name: "big_interval", body: getConfigResp{ Enabled: aghalg.NBTrue, Interval: float64(maxIvl.Milliseconds() + minIvl.Milliseconds()), Ignored: []string{}, }, wantCode: http.StatusUnprocessableEntity, wantErr: "unsupported interval: more than a year\n", }, { name: "set_ignored_ivl_1_maxIvl", body: getConfigResp{ Enabled: aghalg.NBTrue, Interval: float64(maxIvl.Milliseconds()), Ignored: []string{ "ignor.ed", }, }, wantCode: http.StatusOK, wantErr: "", }, { name: "enabled_is_null", body: getConfigResp{ Enabled: aghalg.NBNull, Interval: float64(minIvl.Milliseconds()), Ignored: []string{}, }, wantCode: http.StatusUnprocessableEntity, wantErr: "enabled is null\n", }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s, err := New(conf) require.NoError(t, err) s.Start() testutil.CleanupAndRequireSuccess(t, s.Close) buf, err := json.Marshal(tc.body) require.NoError(t, err) const ( configGet = "/control/stats/config" configPut = "/control/stats/config/update" ) req := httptest.NewRequest(http.MethodPut, configPut, bytes.NewReader(buf)) rw := httptest.NewRecorder() s.handlePutStatsConfig(rw, req) require.Equal(t, tc.wantCode, rw.Code) if tc.wantCode != http.StatusOK { assert.Equal(t, tc.wantErr, rw.Body.String()) return } resp := httptest.NewRequest(http.MethodGet, configGet, nil) rw = httptest.NewRecorder() s.handleGetStatsConfig(rw, resp) require.Equal(t, http.StatusOK, rw.Code) ans := getConfigResp{} err = json.Unmarshal(rw.Body.Bytes(), &ans) require.NoError(t, err) assert.Equal(t, tc.body, ans) }) } }